experimental/io/FsUtil.h \
experimental/Singleton.h \
experimental/TestUtil.h \
- experimental/wangle/channel/AsyncSocketHandler.h \
- experimental/wangle/channel/ChannelHandler.h \
- experimental/wangle/channel/ChannelHandlerContext.h \
- experimental/wangle/channel/ChannelPipeline.h \
- experimental/wangle/channel/OutputBufferingHandler.h \
- experimental/wangle/concurrent/BlockingQueue.h \
- experimental/wangle/concurrent/Codel.h \
- experimental/wangle/concurrent/CPUThreadPoolExecutor.h \
- experimental/wangle/concurrent/FutureExecutor.h \
- experimental/wangle/concurrent/GlobalExecutor.h \
- experimental/wangle/concurrent/IOExecutor.h \
- experimental/wangle/concurrent/IOThreadPoolExecutor.h \
- experimental/wangle/concurrent/LifoSemMPMCQueue.h \
- experimental/wangle/concurrent/NamedThreadFactory.h \
- experimental/wangle/concurrent/ThreadFactory.h \
- experimental/wangle/concurrent/ThreadPoolExecutor.h \
- experimental/wangle/rx/Observable.h \
- experimental/wangle/rx/Observer.h \
- experimental/wangle/rx/Subject.h \
- experimental/wangle/rx/Subscription.h \
- experimental/wangle/rx/types.h \
- experimental/wangle/ConnectionManager.h \
- experimental/wangle/ManagedConnection.h \
- experimental/wangle/acceptor/Acceptor.h \
- experimental/wangle/acceptor/ConnectionCounter.h \
- experimental/wangle/acceptor/SocketOptions.h \
- experimental/wangle/acceptor/DomainNameMisc.h \
- experimental/wangle/acceptor/LoadShedConfiguration.h \
- experimental/wangle/acceptor/NetworkAddress.h \
- experimental/wangle/acceptor/ServerSocketConfig.h \
- experimental/wangle/acceptor/TransportInfo.h \
- experimental/wangle/ssl/ClientHelloExtStats.h \
- experimental/wangle/ssl/DHParam.h \
- experimental/wangle/ssl/PasswordInFile.h \
- experimental/wangle/ssl/SSLCacheOptions.h \
- experimental/wangle/ssl/SSLCacheProvider.h \
- experimental/wangle/ssl/SSLContextConfig.h \
- experimental/wangle/ssl/SSLContextManager.h \
- experimental/wangle/ssl/SSLSessionCacheManager.h \
- experimental/wangle/ssl/SSLStats.h \
- experimental/wangle/ssl/SSLUtil.h \
- experimental/wangle/ssl/TLSTicketKeyManager.h \
- experimental/wangle/ssl/TLSTicketKeySeeds.h \
FBString.h \
FBVector.h \
File.h \
Uri-inl.h \
Varint.h \
VersionCheck.h \
+ wangle/acceptor/Acceptor.h \
+ wangle/acceptor/ConnectionCounter.h \
+ wangle/acceptor/ConnectionManager.h \
+ wangle/acceptor/DomainNameMisc.h \
+ wangle/acceptor/LoadShedConfiguration.h \
+ wangle/acceptor/ManagedConnection.h \
+ wangle/acceptor/NetworkAddress.h \
+ wangle/acceptor/ServerSocketConfig.h \
+ wangle/acceptor/SocketOptions.h \
+ wangle/acceptor/TransportInfo.h \
+ wangle/channel/AsyncSocketHandler.h \
+ wangle/channel/ChannelHandler.h \
+ wangle/channel/ChannelHandlerContext.h \
+ wangle/channel/ChannelPipeline.h \
+ wangle/channel/OutputBufferingHandler.h \
+ wangle/concurrent/BlockingQueue.h \
+ wangle/concurrent/Codel.h \
+ wangle/concurrent/CPUThreadPoolExecutor.h \
+ wangle/concurrent/FutureExecutor.h \
+ wangle/concurrent/IOExecutor.h \
+ wangle/concurrent/IOThreadPoolExecutor.h \
+ wangle/concurrent/GlobalExecutor.h \
+ wangle/concurrent/LifoSemMPMCQueue.h \
+ wangle/concurrent/NamedThreadFactory.h \
+ wangle/concurrent/ThreadFactory.h \
+ wangle/concurrent/ThreadPoolExecutor.h \
wangle/futures/Deprecated.h \
wangle/futures/Future-inl.h \
wangle/futures/Future.h \
wangle/futures/Try.h \
wangle/futures/WangleException.h \
wangle/futures/detail/Core.h \
- wangle/futures/detail/FSM.h
+ wangle/futures/detail/FSM.h \
+ wangle/rx/Observable.h \
+ wangle/rx/Observer.h \
+ wangle/rx/Subject.h \
+ wangle/rx/Subscription.h \
+ wangle/rx/types.h \
+ wangle/ssl/ClientHelloExtStats.h \
+ wangle/ssl/DHParam.h \
+ wangle/ssl/PasswordInFile.h \
+ wangle/ssl/SSLCacheOptions.h \
+ wangle/ssl/SSLCacheProvider.h \
+ wangle/ssl/SSLContextConfig.h \
+ wangle/ssl/SSLContextManager.h \
+ wangle/ssl/SSLSessionCacheManager.h \
+ wangle/ssl/SSLStats.h \
+ wangle/ssl/SSLUtil.h \
+ wangle/ssl/TLSTicketKeyManager.h \
+ wangle/ssl/TLSTicketKeySeeds.h
FormatTables.cpp: build/generate_format_tables.py
build/generate_format_tables.py
TimeoutQueue.cpp \
Uri.cpp \
Version.cpp \
- wangle/futures/InlineExecutor.cpp \
- wangle/futures/ManualExecutor.cpp \
experimental/io/FsUtil.cpp \
experimental/Singleton.cpp \
experimental/TestUtil.cpp \
- experimental/wangle/concurrent/CPUThreadPoolExecutor.cpp \
- experimental/wangle/concurrent/Codel.cpp \
- experimental/wangle/concurrent/GlobalExecutor.cpp \
- experimental/wangle/concurrent/IOExecutor.cpp \
- experimental/wangle/concurrent/IOThreadPoolExecutor.cpp \
- experimental/wangle/concurrent/ThreadPoolExecutor.cpp \
- experimental/wangle/ConnectionManager.cpp \
- experimental/wangle/ManagedConnection.cpp \
- experimental/wangle/acceptor/Acceptor.cpp \
- experimental/wangle/acceptor/SocketOptions.cpp \
- experimental/wangle/acceptor/LoadShedConfiguration.cpp \
- experimental/wangle/acceptor/TransportInfo.cpp \
- experimental/wangle/ssl/PasswordInFile.cpp \
- experimental/wangle/ssl/SSLContextManager.cpp \
- experimental/wangle/ssl/SSLSessionCacheManager.cpp \
- experimental/wangle/ssl/SSLUtil.cpp \
- experimental/wangle/ssl/TLSTicketKeyManager.cpp
+ wangle/acceptor/Acceptor.cpp \
+ wangle/acceptor/ConnectionManager.cpp \
+ wangle/acceptor/LoadShedConfiguration.cpp \
+ wangle/acceptor/ManagedConnection.cpp \
+ wangle/acceptor/SocketOptions.cpp \
+ wangle/acceptor/TransportInfo.cpp \
+ wangle/concurrent/CPUThreadPoolExecutor.cpp \
+ wangle/concurrent/Codel.cpp \
+ wangle/concurrent/IOExecutor.cpp \
+ wangle/concurrent/IOThreadPoolExecutor.cpp \
+ wangle/concurrent/GlobalExecutor.cpp \
+ wangle/concurrent/ThreadPoolExecutor.cpp \
+ wangle/futures/InlineExecutor.cpp \
+ wangle/futures/ManualExecutor.cpp \
+ wangle/ssl/PasswordInFile.cpp \
+ wangle/ssl/SSLContextManager.cpp \
+ wangle/ssl/SSLSessionCacheManager.cpp \
+ wangle/ssl/SSLUtil.cpp \
+ wangle/ssl/TLSTicketKeyManager.cpp
if HAVE_LINUX
nobase_follyinclude_HEADERS += \
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/ConnectionManager.h>
-
-#include <glog/logging.h>
-#include <folly/io/async/EventBase.h>
-
-using folly::HHWheelTimer;
-using std::chrono::milliseconds;
-
-namespace folly { namespace wangle {
-
-ConnectionManager::ConnectionManager(EventBase* eventBase,
- milliseconds timeout, Callback* callback)
- : connTimeouts_(new HHWheelTimer(eventBase)),
- callback_(callback),
- eventBase_(eventBase),
- idleIterator_(conns_.end()),
- idleLoopCallback_(this),
- timeout_(timeout) {
-
-}
-
-void
-ConnectionManager::addConnection(ManagedConnection* connection,
- bool timeout) {
- CHECK_NOTNULL(connection);
- ConnectionManager* oldMgr = connection->getConnectionManager();
- if (oldMgr != this) {
- if (oldMgr) {
- // 'connection' was being previously managed in a different thread.
- // We must remove it from that manager before adding it to this one.
- oldMgr->removeConnection(connection);
- }
- conns_.push_back(*connection);
- connection->setConnectionManager(this);
- if (callback_) {
- callback_->onConnectionAdded(*this);
- }
- }
- if (timeout) {
- scheduleTimeout(connection);
- }
-}
-
-void
-ConnectionManager::scheduleTimeout(ManagedConnection* connection) {
- if (timeout_ > std::chrono::milliseconds(0)) {
- connTimeouts_->scheduleTimeout(connection, timeout_);
- }
-}
-
-void ConnectionManager::scheduleTimeout(
- folly::HHWheelTimer::Callback* callback,
- std::chrono::milliseconds timeout) {
- connTimeouts_->scheduleTimeout(callback, timeout);
-}
-
-void
-ConnectionManager::removeConnection(ManagedConnection* connection) {
- if (connection->getConnectionManager() == this) {
- connection->cancelTimeout();
- connection->setConnectionManager(nullptr);
-
- // Un-link the connection from our list, being careful to keep the iterator
- // that we're using for idle shedding valid
- auto it = conns_.iterator_to(*connection);
- if (it == idleIterator_) {
- ++idleIterator_;
- }
- conns_.erase(it);
-
- if (callback_) {
- callback_->onConnectionRemoved(*this);
- if (getNumConnections() == 0) {
- callback_->onEmpty(*this);
- }
- }
- }
-}
-
-void
-ConnectionManager::initiateGracefulShutdown(
- std::chrono::milliseconds idleGrace) {
- if (idleGrace.count() > 0) {
- idleLoopCallback_.scheduleTimeout(idleGrace);
- VLOG(3) << "Scheduling idle grace period of " << idleGrace.count() << "ms";
- } else {
- action_ = ShutdownAction::DRAIN2;
- VLOG(3) << "proceeding directly to closing idle connections";
- }
- drainAllConnections();
-}
-
-void
-ConnectionManager::drainAllConnections() {
- DestructorGuard g(this);
- size_t numCleared = 0;
- size_t numKept = 0;
-
- auto it = idleIterator_ == conns_.end() ?
- conns_.begin() : idleIterator_;
-
- while (it != conns_.end() && (numKept + numCleared) < 64) {
- ManagedConnection& conn = *it++;
- if (action_ == ShutdownAction::DRAIN1) {
- conn.notifyPendingShutdown();
- } else {
- // Second time around: close idle sessions. If they aren't idle yet,
- // have them close when they are idle
- if (conn.isBusy()) {
- numKept++;
- } else {
- numCleared++;
- }
- conn.closeWhenIdle();
- }
- }
-
- if (action_ == ShutdownAction::DRAIN2) {
- VLOG(2) << "Idle connections cleared: " << numCleared <<
- ", busy conns kept: " << numKept;
- }
- if (it != conns_.end()) {
- idleIterator_ = it;
- eventBase_->runInLoop(&idleLoopCallback_);
- } else {
- action_ = ShutdownAction::DRAIN2;
- }
-}
-
-void
-ConnectionManager::dropAllConnections() {
- DestructorGuard g(this);
-
- // Iterate through our connection list, and drop each connection.
- VLOG(3) << "connections to drop: " << conns_.size();
- idleLoopCallback_.cancelTimeout();
- unsigned i = 0;
- while (!conns_.empty()) {
- ManagedConnection& conn = conns_.front();
- conns_.pop_front();
- conn.cancelTimeout();
- conn.setConnectionManager(nullptr);
- // For debugging purposes, dump information about the first few
- // connections.
- static const unsigned MAX_CONNS_TO_DUMP = 2;
- if (++i <= MAX_CONNS_TO_DUMP) {
- conn.dumpConnectionState(3);
- }
- conn.dropConnection();
- }
- idleIterator_ = conns_.end();
- idleLoopCallback_.cancelLoopCallback();
-
- if (callback_) {
- callback_->onEmpty(*this);
- }
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/ManagedConnection.h>
-
-#include <chrono>
-#include <folly/Memory.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/io/async/HHWheelTimer.h>
-#include <folly/io/async/DelayedDestruction.h>
-#include <folly/io/async/EventBase.h>
-
-namespace folly { namespace wangle {
-
-/**
- * A ConnectionManager keeps track of ManagedConnections.
- */
-class ConnectionManager: public folly::DelayedDestruction {
- public:
-
- /**
- * Interface for an optional observer that's notified about
- * various events in a ConnectionManager
- */
- class Callback {
- public:
- virtual ~Callback() {}
-
- /**
- * Invoked when the number of connections managed by the
- * ConnectionManager changes from nonzero to zero.
- */
- virtual void onEmpty(const ConnectionManager& cm) = 0;
-
- /**
- * Invoked when a connection is added to the ConnectionManager.
- */
- virtual void onConnectionAdded(const ConnectionManager& cm) = 0;
-
- /**
- * Invoked when a connection is removed from the ConnectionManager.
- */
- virtual void onConnectionRemoved(const ConnectionManager& cm) = 0;
- };
-
- typedef std::unique_ptr<ConnectionManager, Destructor> UniquePtr;
-
- /**
- * Returns a new instance of ConnectionManager wrapped in a unique_ptr
- */
- template<typename... Args>
- static UniquePtr makeUnique(Args&&... args) {
- return folly::make_unique<ConnectionManager, Destructor>(
- std::forward<Args>(args)...);
- }
-
- /**
- * Constructor not to be used by itself.
- */
- ConnectionManager(folly::EventBase* eventBase,
- std::chrono::milliseconds timeout,
- Callback* callback = nullptr);
-
- /**
- * Add a connection to the set of connections managed by this
- * ConnectionManager.
- *
- * @param connection The connection to add.
- * @param timeout Whether to immediately register this connection
- * for an idle timeout callback.
- */
- void addConnection(ManagedConnection* connection,
- bool timeout = false);
-
- /**
- * Schedule a timeout callback for a connection.
- */
- void scheduleTimeout(ManagedConnection* connection);
-
- /*
- * Schedule a callback on the wheel timer
- */
- void scheduleTimeout(folly::HHWheelTimer::Callback* callback,
- std::chrono::milliseconds timeout);
-
- /**
- * Remove a connection from this ConnectionManager and, if
- * applicable, cancel the pending timeout callback that the
- * ConnectionManager has scheduled for the connection.
- *
- * @note This method does NOT destroy the connection.
- */
- void removeConnection(ManagedConnection* connection);
-
- /* Begin gracefully shutting down connections in this ConnectionManager.
- * Notify all connections of pending shutdown, and after idleGrace,
- * begin closing idle connections.
- */
- void initiateGracefulShutdown(std::chrono::milliseconds idleGrace);
-
- /**
- * Destroy all connections Managed by this ConnectionManager, even
- * the ones that are busy.
- */
- void dropAllConnections();
-
- size_t getNumConnections() const { return conns_.size(); }
-
- template <typename F>
- void iterateConns(F func) {
- auto it = conns_.begin();
- while ( it != conns_.end()) {
- func(&(*it));
- it++;
- }
- }
-
- private:
- class CloseIdleConnsCallback :
- public folly::EventBase::LoopCallback,
- public folly::AsyncTimeout {
- public:
- explicit CloseIdleConnsCallback(ConnectionManager* manager)
- : folly::AsyncTimeout(manager->eventBase_),
- manager_(manager) {}
-
- void runLoopCallback() noexcept override {
- VLOG(3) << "Draining more conns from loop callback";
- manager_->drainAllConnections();
- }
-
- void timeoutExpired() noexcept override {
- VLOG(3) << "Idle grace expired";
- manager_->drainAllConnections();
- }
-
- private:
- ConnectionManager* manager_;
- };
-
- enum class ShutdownAction : uint8_t {
- /**
- * Drain part 1: inform remote that you will soon reject new requests.
- */
- DRAIN1 = 0,
- /**
- * Drain part 2: start rejecting new requests.
- */
- DRAIN2 = 1,
- };
-
- ~ConnectionManager() {}
-
- ConnectionManager(const ConnectionManager&) = delete;
- ConnectionManager& operator=(ConnectionManager&) = delete;
-
- /**
- * Destroy all connections managed by this ConnectionManager that
- * are currently idle, as determined by a call to each ManagedConnection's
- * isBusy() method.
- */
- void drainAllConnections();
-
- /** All connections */
- folly::CountedIntrusiveList<
- ManagedConnection,&ManagedConnection::listHook_> conns_;
-
- /** Connections that currently are registered for timeouts */
- folly::HHWheelTimer::UniquePtr connTimeouts_;
-
- /** Optional callback to notify of state changes */
- Callback* callback_;
-
- /** Event base in which we run */
- folly::EventBase* eventBase_;
-
- /** Iterator to the next connection to shed; used by drainAllConnections() */
- folly::CountedIntrusiveList<
- ManagedConnection,&ManagedConnection::listHook_>::iterator idleIterator_;
- CloseIdleConnsCallback idleLoopCallback_;
- ShutdownAction action_{ShutdownAction::DRAIN1};
- std::chrono::milliseconds timeout_;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/ManagedConnection.h>
-
-#include <folly/experimental/wangle/ConnectionManager.h>
-
-namespace folly { namespace wangle {
-
-ManagedConnection::ManagedConnection()
- : connectionManager_(nullptr) {
-}
-
-ManagedConnection::~ManagedConnection() {
- if (connectionManager_) {
- connectionManager_->removeConnection(this);
- }
-}
-
-void
-ManagedConnection::resetTimeout() {
- if (connectionManager_) {
- connectionManager_->scheduleTimeout(this);
- }
-}
-
-void
-ManagedConnection::scheduleTimeout(
- folly::HHWheelTimer::Callback* callback,
- std::chrono::milliseconds timeout) {
- if (connectionManager_) {
- connectionManager_->scheduleTimeout(callback, timeout);
- }
-}
-
-////////////////////// Globals /////////////////////
-
-std::ostream&
-operator<<(std::ostream& os, const ManagedConnection& conn) {
- conn.describe(os);
- return os;
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/IntrusiveList.h>
-#include <ostream>
-#include <folly/io/async/HHWheelTimer.h>
-#include <folly/io/async/DelayedDestruction.h>
-
-namespace folly { namespace wangle {
-
-class ConnectionManager;
-
-/**
- * Interface describing a connection that can be managed by a
- * container such as an Acceptor.
- */
-class ManagedConnection:
- public folly::HHWheelTimer::Callback,
- public folly::DelayedDestruction {
- public:
-
- ManagedConnection();
-
- // HHWheelTimer::Callback API (left for subclasses to implement).
- virtual void timeoutExpired() noexcept = 0;
-
- /**
- * Print a human-readable description of the connection.
- * @param os Destination stream.
- */
- virtual void describe(std::ostream& os) const = 0;
-
- /**
- * Check whether the connection has any requests outstanding.
- */
- virtual bool isBusy() const = 0;
-
- /**
- * Notify the connection that a shutdown is pending. This method will be
- * called at the beginning of graceful shutdown.
- */
- virtual void notifyPendingShutdown() = 0;
-
- /**
- * Instruct the connection that it should shutdown as soon as it is
- * safe. This is called after notifyPendingShutdown().
- */
- virtual void closeWhenIdle() = 0;
-
- /**
- * Forcibly drop a connection.
- *
- * If a request is in progress, this should cause the connection to be
- * closed with a reset.
- */
- virtual void dropConnection() = 0;
-
- /**
- * Dump the state of the connection to the log
- */
- virtual void dumpConnectionState(uint8_t loglevel) = 0;
-
- /**
- * If the connection has a connection manager, reset the timeout
- * countdown.
- * @note If the connection manager doesn't have the connection scheduled
- * for a timeout already, this method will schedule one. If the
- * connection manager does have the connection connection scheduled
- * for a timeout, this method will push back the timeout to N msec
- * from now, where N is the connection manager's timer interval.
- */
- virtual void resetTimeout();
-
- // Schedule an arbitrary timeout on the HHWheelTimer
- virtual void scheduleTimeout(
- folly::HHWheelTimer::Callback* callback,
- std::chrono::milliseconds timeout);
-
- ConnectionManager* getConnectionManager() {
- return connectionManager_;
- }
-
- protected:
- virtual ~ManagedConnection();
-
- private:
- friend class ConnectionManager;
-
- void setConnectionManager(ConnectionManager* mgr) {
- connectionManager_ = mgr;
- }
-
- ConnectionManager* connectionManager_;
-
- folly::SafeIntrusiveListHook listHook_;
-};
-
-std::ostream& operator<<(std::ostream& os, const ManagedConnection& conn);
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/acceptor/Acceptor.h>
-
-#include <folly/experimental/wangle/ManagedConnection.h>
-#include <folly/experimental/wangle/ssl/SSLContextManager.h>
-
-#include <boost/cast.hpp>
-#include <fcntl.h>
-#include <folly/ScopeGuard.h>
-#include <folly/experimental/wangle/ManagedConnection.h>
-#include <folly/io/async/EventBase.h>
-#include <fstream>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <folly/io/async/AsyncSSLSocket.h>
-#include <folly/io/async/AsyncSocket.h>
-#include <folly/io/async/EventBase.h>
-#include <unistd.h>
-
-using folly::wangle::ConnectionManager;
-using folly::wangle::ManagedConnection;
-using std::chrono::microseconds;
-using std::chrono::milliseconds;
-using std::filebuf;
-using std::ifstream;
-using std::ios;
-using std::shared_ptr;
-using std::string;
-
-namespace folly {
-
-#ifndef NO_LIB_GFLAGS
-DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
- "closing idle conns");
-#else
-const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
-#endif
-
-static const std::string empty_string;
-std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
-
-/**
- * Lightweight wrapper class to keep track of a newly
- * accepted connection during SSL handshaking.
- */
-class AcceptorHandshakeHelper :
- public AsyncSSLSocket::HandshakeCB,
- public ManagedConnection {
- public:
- AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
- Acceptor* acceptor,
- const SocketAddress& clientAddr,
- std::chrono::steady_clock::time_point acceptTime)
- : socket_(std::move(socket)), acceptor_(acceptor),
- acceptTime_(acceptTime), clientAddr_(clientAddr) {
- acceptor_->downstreamConnectionManager_->addConnection(this, true);
- if(acceptor_->parseClientHello_) {
- socket_->enableClientHelloParsing();
- }
- socket_->sslAccept(this);
- }
-
- virtual void timeoutExpired() noexcept {
- VLOG(4) << "SSL handshake timeout expired";
- sslError_ = SSLErrorEnum::TIMEOUT;
- dropConnection();
- }
- virtual void describe(std::ostream& os) const {
- os << "pending handshake on " << clientAddr_;
- }
- virtual bool isBusy() const {
- return true;
- }
- virtual void notifyPendingShutdown() {}
- virtual void closeWhenIdle() {}
-
- virtual void dropConnection() {
- VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
- socket_->closeNow();
- }
- virtual void dumpConnectionState(uint8_t loglevel) {
- }
-
- private:
- // AsyncSSLSocket::HandshakeCallback API
- virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
-
- const unsigned char* nextProto = nullptr;
- unsigned nextProtoLength = 0;
- sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
- if (VLOG_IS_ON(3)) {
- if (nextProto) {
- VLOG(3) << "Client selected next protocol " <<
- string((const char*)nextProto, nextProtoLength);
- } else {
- VLOG(3) << "Client did not select a next protocol";
- }
- }
-
- // fill in SSL-related fields from TransportInfo
- // the other fields like RTT are filled in the Acceptor
- TransportInfo tinfo;
- tinfo.ssl = true;
- tinfo.acceptTime = acceptTime_;
- tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
- tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
- tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
- tinfo.sslServerName = sock->getSSLServerName();
- tinfo.sslCipher = sock->getNegotiatedCipherName();
- tinfo.sslVersion = sock->getSSLVersion();
- tinfo.sslCertSize = sock->getSSLCertSize();
- tinfo.sslResume = SSLUtil::getResumeState(sock);
- sock->getSSLClientCiphers(tinfo.sslClientCiphers);
- sock->getSSLServerCiphers(tinfo.sslServerCiphers);
- tinfo.sslClientComprMethods = sock->getSSLClientComprMethods();
- tinfo.sslClientExts = sock->getSSLClientExts();
- tinfo.sslNextProtocol.assign(
- reinterpret_cast<const char*>(nextProto),
- nextProtoLength);
-
- acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
- acceptor_->downstreamConnectionManager_->removeConnection(this);
- acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
- nextProto ? string((const char*)nextProto, nextProtoLength) :
- empty_string, tinfo);
- delete this;
- }
-
- virtual void handshakeErr(AsyncSSLSocket* sock,
- const AsyncSocketException& ex) noexcept {
- auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
- VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
- " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
- sock->getRawBytesWritten() << " bytes sent: " <<
- ex.what();
- acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
- acceptor_->sslConnectionError();
- delete this;
- }
-
- AsyncSSLSocket::UniquePtr socket_;
- Acceptor* acceptor_;
- std::chrono::steady_clock::time_point acceptTime_;
- SocketAddress clientAddr_;
- SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
-};
-
-Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
- accConfig_(accConfig),
- socketOptions_(accConfig.getSocketOptions()) {
-}
-
-void
-Acceptor::init(AsyncServerSocket* serverSocket,
- EventBase* eventBase) {
- CHECK(nullptr == this->base_);
-
- if (accConfig_.isSSL()) {
- if (!sslCtxManager_) {
- sslCtxManager_ = folly::make_unique<SSLContextManager>(
- eventBase,
- "vip_" + getName(),
- accConfig_.strictSSL, nullptr);
- }
- for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
- sslCtxManager_->addSSLContextConfig(
- sslCtxConfig,
- accConfig_.sslCacheOptions,
- &accConfig_.initialTicketSeeds,
- accConfig_.bindAddress,
- cacheProvider_);
- parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
- }
-
- CHECK(sslCtxManager_->getDefaultSSLCtx());
- }
-
- base_ = eventBase;
- state_ = State::kRunning;
- downstreamConnectionManager_ = ConnectionManager::makeUnique(
- eventBase, accConfig_.connectionIdleTimeout, this);
-
- if (serverSocket) {
- serverSocket->addAcceptCallback(this, eventBase);
- // SO_KEEPALIVE is the only setting that is inherited by accepted
- // connections so only apply this setting
- for (const auto& option: socketOptions_) {
- if (option.first.level == SOL_SOCKET &&
- option.first.optname == SO_KEEPALIVE && option.second == 1) {
- serverSocket->setKeepAliveEnabled(true);
- break;
- }
- }
- }
-}
-
-Acceptor::~Acceptor(void) {
-}
-
-void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
- sslCtxManager_->addSSLContextConfig(sslCtxConfig,
- accConfig_.sslCacheOptions,
- &accConfig_.initialTicketSeeds,
- accConfig_.bindAddress,
- cacheProvider_);
-}
-
-void
-Acceptor::drainAllConnections() {
- if (downstreamConnectionManager_) {
- downstreamConnectionManager_->initiateGracefulShutdown(
- std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
- }
-}
-
-void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
- IConnectionCounter* counter) {
- loadShedConfig_ = from;
- connectionCounter_ = counter;
-}
-
-bool Acceptor::canAccept(const SocketAddress& address) {
- if (!connectionCounter_) {
- return true;
- }
-
- uint64_t maxConnections = connectionCounter_->getMaxConnections();
- if (maxConnections == 0) {
- return true;
- }
-
- uint64_t currentConnections = connectionCounter_->getNumConnections();
- if (currentConnections < maxConnections) {
- return true;
- }
-
- if (loadShedConfig_.isWhitelisted(address)) {
- return true;
- }
-
- // Take care of comparing connection count against max connections across
- // all acceptors. Expensive since a lock must be taken to get the counter.
- auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
- if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
- return true;
- }
-
- VLOG(4) << address.describe() << " not whitelisted";
- return false;
-}
-
-void
-Acceptor::connectionAccepted(
- int fd, const SocketAddress& clientAddr) noexcept {
- if (!canAccept(clientAddr)) {
- close(fd);
- return;
- }
- auto acceptTime = std::chrono::steady_clock::now();
- for (const auto& opt: socketOptions_) {
- opt.first.apply(fd, opt.second);
- }
-
- onDoneAcceptingConnection(fd, clientAddr, acceptTime);
-}
-
-void Acceptor::onDoneAcceptingConnection(
- int fd,
- const SocketAddress& clientAddr,
- std::chrono::steady_clock::time_point acceptTime) noexcept {
- processEstablishedConnection(fd, clientAddr, acceptTime);
-}
-
-void
-Acceptor::processEstablishedConnection(
- int fd,
- const SocketAddress& clientAddr,
- std::chrono::steady_clock::time_point acceptTime) noexcept {
- if (accConfig_.isSSL()) {
- CHECK(sslCtxManager_);
- AsyncSSLSocket::UniquePtr sslSock(
- makeNewAsyncSSLSocket(
- sslCtxManager_->getDefaultSSLCtx(), base_, fd));
- ++numPendingSSLConns_;
- ++totalNumPendingSSLConns_;
- if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
- VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
- " too many handshakes in progress";
- updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
- SSLErrorEnum::DROPPED);
- sslConnectionError();
- return;
- }
- new AcceptorHandshakeHelper(
- std::move(sslSock), this, clientAddr, acceptTime);
- } else {
- TransportInfo tinfo;
- tinfo.ssl = false;
- tinfo.acceptTime = acceptTime;
- AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
- connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
- }
-}
-
-void
-Acceptor::connectionReady(
- AsyncSocket::UniquePtr sock,
- const SocketAddress& clientAddr,
- const string& nextProtocolName,
- TransportInfo& tinfo) {
- // Limit the number of reads from the socket per poll loop iteration,
- // both to keep memory usage under control and to prevent one fast-
- // writing client from starving other connections.
- sock->setMaxReadsPerEvent(16);
- tinfo.initWithSocket(sock.get());
- onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
-}
-
-void
-Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
- const SocketAddress& clientAddr,
- const string& nextProtocol,
- TransportInfo& tinfo) {
- CHECK(numPendingSSLConns_ > 0);
- connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
- --numPendingSSLConns_;
- --totalNumPendingSSLConns_;
- if (state_ == State::kDraining) {
- checkDrained();
- }
-}
-
-void
-Acceptor::sslConnectionError() {
- CHECK(numPendingSSLConns_ > 0);
- --numPendingSSLConns_;
- --totalNumPendingSSLConns_;
- if (state_ == State::kDraining) {
- checkDrained();
- }
-}
-
-void
-Acceptor::acceptError(const std::exception& ex) noexcept {
- // An error occurred.
- // The most likely error is out of FDs. AsyncServerSocket will back off
- // briefly if we are out of FDs, then continue accepting later.
- // Just log a message here.
- LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
-}
-
-void
-Acceptor::acceptStopped() noexcept {
- VLOG(3) << "Acceptor " << this << " acceptStopped()";
- // Drain the open client connections
- drainAllConnections();
-
- // If we haven't yet finished draining, begin doing so by marking ourselves
- // as in the draining state. We must be sure to hit checkDrained() here, as
- // if we're completely idle, we can should consider ourself drained
- // immediately (as there is no outstanding work to complete to cause us to
- // re-evaluate this).
- if (state_ != State::kDone) {
- state_ = State::kDraining;
- checkDrained();
- }
-}
-
-void
-Acceptor::onEmpty(const ConnectionManager& cm) {
- VLOG(3) << "Acceptor=" << this << " onEmpty()";
- if (state_ == State::kDraining) {
- checkDrained();
- }
-}
-
-void
-Acceptor::checkDrained() {
- CHECK(state_ == State::kDraining);
- if (forceShutdownInProgress_ ||
- (downstreamConnectionManager_->getNumConnections() != 0) ||
- (numPendingSSLConns_ != 0)) {
- return;
- }
-
- VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
- << base_;
-
- downstreamConnectionManager_.reset();
-
- state_ = State::kDone;
-
- onConnectionsDrained();
-}
-
-milliseconds
-Acceptor::getConnTimeout() const {
- return accConfig_.connectionIdleTimeout;
-}
-
-void Acceptor::addConnection(ManagedConnection* conn) {
- // Add the socket to the timeout manager so that it can be cleaned
- // up after being left idle for a long time.
- downstreamConnectionManager_->addConnection(conn, true);
-}
-
-void
-Acceptor::forceStop() {
- base_->runInEventBaseThread([&] { dropAllConnections(); });
-}
-
-void
-Acceptor::dropAllConnections() {
- if (downstreamConnectionManager_) {
- VLOG(3) << "Dropping all connections from Acceptor=" << this <<
- " in thread " << base_;
- assert(base_->isInEventBaseThread());
- forceShutdownInProgress_ = true;
- downstreamConnectionManager_->dropAllConnections();
- CHECK(downstreamConnectionManager_->getNumConnections() == 0);
- downstreamConnectionManager_.reset();
- }
- CHECK(numPendingSSLConns_ == 0);
-
- state_ = State::kDone;
- onConnectionsDrained();
-}
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include "folly/experimental/wangle/acceptor/ServerSocketConfig.h"
-#include "folly/experimental/wangle/acceptor/ConnectionCounter.h"
-#include <folly/experimental/wangle/ConnectionManager.h>
-#include "folly/experimental/wangle/acceptor/LoadShedConfiguration.h"
-#include "folly/experimental/wangle/ssl/SSLCacheProvider.h"
-#include "folly/experimental/wangle/acceptor/TransportInfo.h"
-
-#include <chrono>
-#include <event.h>
-#include <folly/io/async/AsyncSSLSocket.h>
-#include <folly/io/async/AsyncServerSocket.h>
-
-namespace folly { namespace wangle {
-class ManagedConnection;
-}}
-
-namespace folly {
-
-class SocketAddress;
-class SSLContext;
-class AsyncTransport;
-class SSLContextManager;
-
-/**
- * An abstract acceptor for TCP-based network services.
- *
- * There is one acceptor object per thread for each listening socket. When a
- * new connection arrives on the listening socket, it is accepted by one of the
- * acceptor objects. From that point on the connection will be processed by
- * that acceptor's thread.
- *
- * The acceptor will call the abstract onNewConnection() method to create
- * a new ManagedConnection object for each accepted socket. The acceptor
- * also tracks all outstanding connections that it has accepted.
- */
-class Acceptor :
- public folly::AsyncServerSocket::AcceptCallback,
- public folly::wangle::ConnectionManager::Callback {
- public:
-
- enum class State : uint32_t {
- kInit, // not yet started
- kRunning, // processing requests normally
- kDraining, // processing outstanding conns, but not accepting new ones
- kDone, // no longer accepting, and all connections finished
- };
-
- explicit Acceptor(const ServerSocketConfig& accConfig);
- virtual ~Acceptor();
-
- /**
- * Supply an SSL cache provider
- * @note Call this before init()
- */
- virtual void setSSLCacheProvider(
- const std::shared_ptr<SSLCacheProvider>& cacheProvider) {
- cacheProvider_ = cacheProvider;
- }
-
- /**
- * Initialize the Acceptor to run in the specified EventBase
- * thread, receiving connections from the specified AsyncServerSocket.
- *
- * This method will be called from the AsyncServerSocket's primary thread,
- * not the specified EventBase thread.
- */
- virtual void init(AsyncServerSocket* serverSocket,
- EventBase* eventBase);
-
- /**
- * Dynamically add a new SSLContextConfig
- */
- void addSSLContextConfig(const SSLContextConfig& sslCtxConfig);
-
- SSLContextManager* getSSLContextManager() const {
- return sslCtxManager_.get();
- }
-
- /**
- * Return the number of outstanding connections in this service instance.
- */
- uint32_t getNumConnections() const {
- return downstreamConnectionManager_ ?
- downstreamConnectionManager_->getNumConnections() : 0;
- }
-
- /**
- * Access the Acceptor's event base.
- */
- virtual EventBase* getEventBase() const { return base_; }
-
- /**
- * Access the Acceptor's downstream (client-side) ConnectionManager
- */
- virtual folly::wangle::ConnectionManager* getConnectionManager() {
- return downstreamConnectionManager_.get();
- }
-
- /**
- * Invoked when a new ManagedConnection is created.
- *
- * This allows the Acceptor to track the outstanding connections,
- * for tracking timeouts and for ensuring that all connections have been
- * drained on shutdown.
- */
- void addConnection(folly::wangle::ManagedConnection* connection);
-
- /**
- * Get this acceptor's current state.
- */
- State getState() const {
- return state_;
- }
-
- /**
- * Get the current connection timeout.
- */
- std::chrono::milliseconds getConnTimeout() const;
-
- /**
- * Returns the name of this VIP.
- *
- * Will return an empty string if no name has been configured.
- */
- const std::string& getName() const {
- return accConfig_.name;
- }
-
- /**
- * Force the acceptor to drop all connections and stop processing.
- *
- * This function may be called from any thread. The acceptor will not
- * necessarily stop before this function returns: the stop will be scheduled
- * to run in the acceptor's thread.
- */
- virtual void forceStop();
-
- bool isSSL() const { return accConfig_.isSSL(); }
-
- const ServerSocketConfig& getConfig() const { return accConfig_; }
-
- static uint64_t getTotalNumPendingSSLConns() {
- return totalNumPendingSSLConns_.load();
- }
-
- /**
- * Called right when the TCP connection has been accepted, before processing
- * the first HTTP bytes (HTTP) or the SSL handshake (HTTPS)
- */
- virtual void onDoneAcceptingConnection(
- int fd,
- const SocketAddress& clientAddr,
- std::chrono::steady_clock::time_point acceptTime
- ) noexcept;
-
- /**
- * Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS)
- */
- void processEstablishedConnection(
- int fd,
- const SocketAddress& clientAddr,
- std::chrono::steady_clock::time_point acceptTime
- ) noexcept;
-
- /**
- * Drains all open connections of their outstanding transactions. When
- * a connection's transaction count reaches zero, the connection closes.
- */
- void drainAllConnections();
-
- protected:
- friend class AcceptorHandshakeHelper;
-
- /**
- * Our event loop.
- *
- * Probably needs to be used to pass to a ManagedConnection
- * implementation. Also visible in case a subclass wishes to do additional
- * things w/ the event loop (e.g. in attach()).
- */
- EventBase* base_{nullptr};
-
- virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; }
-
- /**
- * Hook for subclasses to drop newly accepted connections prior
- * to handshaking.
- */
- virtual bool canAccept(const folly::SocketAddress&);
-
- /**
- * Invoked when a new connection is created. This is where application starts
- * processing a new downstream connection.
- *
- * NOTE: Application should add the new connection to
- * downstreamConnectionManager so that it can be garbage collected after
- * certain period of idleness.
- *
- * @param sock the socket connected to the client
- * @param address the address of the client
- * @param nextProtocolName the name of the L6 or L7 protocol to be
- * spoken on the connection, if known (e.g.,
- * from TLS NPN during secure connection setup),
- * or an empty string if unknown
- */
- virtual void onNewConnection(
- AsyncSocket::UniquePtr sock,
- const folly::SocketAddress* address,
- const std::string& nextProtocolName,
- const TransportInfo& tinfo) = 0;
-
- virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) {
- return AsyncSocket::UniquePtr(new AsyncSocket(base, fd));
- }
-
- virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket(
- const std::shared_ptr<SSLContext>& ctx, EventBase* base, int fd) {
- return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd));
- }
-
- /**
- * Hook for subclasses to record stats about SSL connection establishment.
- */
- virtual void updateSSLStats(
- const AsyncSSLSocket* sock,
- std::chrono::milliseconds acceptLatency,
- SSLErrorEnum error) noexcept {}
-
- /**
- * Drop all connections.
- *
- * forceStop() schedules dropAllConnections() to be called in the acceptor's
- * thread.
- */
- void dropAllConnections();
-
- protected:
-
- /**
- * onConnectionsDrained() will be called once all connections have been
- * drained while the acceptor is stopping.
- *
- * Subclasses can override this method to perform any subclass-specific
- * cleanup.
- */
- virtual void onConnectionsDrained() {}
-
- // AsyncServerSocket::AcceptCallback methods
- void connectionAccepted(int fd,
- const folly::SocketAddress& clientAddr)
- noexcept;
- void acceptError(const std::exception& ex) noexcept;
- void acceptStopped() noexcept;
-
- // ConnectionManager::Callback methods
- void onEmpty(const folly::wangle::ConnectionManager& cm);
- void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {}
- void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {}
-
- /**
- * Process a connection that is to ready to receive L7 traffic.
- * This method is called immediately upon accept for plaintext
- * connections and upon completion of SSL handshaking or resumption
- * for SSL connections.
- */
- void connectionReady(
- AsyncSocket::UniquePtr sock,
- const folly::SocketAddress& clientAddr,
- const std::string& nextProtocolName,
- TransportInfo& tinfo);
-
- const LoadShedConfiguration& getLoadShedConfiguration() const {
- return loadShedConfig_;
- }
-
- protected:
- const ServerSocketConfig accConfig_;
- void setLoadShedConfig(const LoadShedConfiguration& from,
- IConnectionCounter* counter);
-
- /**
- * Socket options to apply to the client socket
- */
- AsyncSocket::OptionMap socketOptions_;
-
- std::unique_ptr<SSLContextManager> sslCtxManager_;
-
- /**
- * Whether we want to enable client hello parsing in the handshake helper
- * to get list of supported client ciphers.
- */
- bool parseClientHello_{false};
-
- folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_;
-
- private:
-
- // Forbidden copy constructor and assignment opererator
- Acceptor(Acceptor const &) = delete;
- Acceptor& operator=(Acceptor const &) = delete;
-
- /**
- * Wrapper for connectionReady() that decrements the count of
- * pending SSL connections.
- */
- void sslConnectionReady(AsyncSocket::UniquePtr sock,
- const folly::SocketAddress& clientAddr,
- const std::string& nextProtocol,
- TransportInfo& tinfo);
-
- /**
- * Notification callback for SSL handshake failures.
- */
- void sslConnectionError();
-
- void checkDrained();
-
- State state_{State::kInit};
- uint64_t numPendingSSLConns_{0};
-
- static std::atomic<uint64_t> totalNumPendingSSLConns_;
-
- bool forceShutdownInProgress_{false};
- LoadShedConfiguration loadShedConfig_;
- IConnectionCounter* connectionCounter_{nullptr};
- std::shared_ptr<SSLCacheProvider> cacheProvider_;
-};
-
-class AcceptorFactory {
- public:
- virtual std::shared_ptr<Acceptor> newAcceptor() = 0;
- virtual ~AcceptorFactory() = default;
-};
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-namespace folly {
-
-class IConnectionCounter {
- public:
- virtual uint64_t getNumConnections() const = 0;
-
- /**
- * Get the maximum number of non-whitelisted client-side connections
- * across all Acceptors managed by this. A value
- * of zero means "unlimited."
- */
- virtual uint64_t getMaxConnections() const = 0;
-
- /**
- * Increment the count of client-side connections.
- */
- virtual void onConnectionAdded() = 0;
-
- /**
- * Decrement the count of client-side connections.
- */
- virtual void onConnectionRemoved() = 0;
- virtual ~IConnectionCounter() {}
-};
-
-class SimpleConnectionCounter: public IConnectionCounter {
- public:
- uint64_t getNumConnections() const override { return numConnections_; }
- uint64_t getMaxConnections() const override { return maxConnections_; }
- void setMaxConnections(uint64_t maxConnections) {
- maxConnections_ = maxConnections;
- }
-
- void onConnectionAdded() override { numConnections_++; }
- void onConnectionRemoved() override { numConnections_--; }
- virtual ~SimpleConnectionCounter() {}
-
- protected:
- uint64_t maxConnections_{0};
- uint64_t numConnections_{0};
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <string>
-
-namespace folly {
-
-struct dn_char_traits : public std::char_traits<char> {
- static bool eq(char c1, char c2) {
- return ::tolower(c1) == ::tolower(c2);
- }
-
- static bool ne(char c1, char c2) {
- return ::tolower(c1) != ::tolower(c2);
- }
-
- static bool lt(char c1, char c2) {
- return ::tolower(c1) < ::tolower(c2);
- }
-
- static int compare(const char* s1, const char* s2, size_t n) {
- while (n--) {
- if(::tolower(*s1) < ::tolower(*s2) ) {
- return -1;
- }
- if(::tolower(*s1) > ::tolower(*s2) ) {
- return 1;
- }
- ++s1;
- ++s2;
- }
- return 0;
- }
-
- static const char* find(const char* s, size_t n, char a) {
- char la = ::tolower(a);
- while (n--) {
- if(::tolower(*s) == la) {
- return s;
- } else {
- ++s;
- }
- }
- return nullptr;
- }
-};
-
-// Case insensitive string
-typedef std::basic_string<char, dn_char_traits> DNString;
-
-struct DNStringHash : public std::hash<std::string> {
- size_t operator()(const DNString& s) const noexcept {
- size_t h = static_cast<size_t>(0xc70f6907UL);
- const char* d = s.data();
- for (size_t i = 0; i < s.length(); ++i) {
- char a = ::tolower(*d++);
- h = std::_Hash_impl::hash(&a, sizeof(a), h);
- }
- return h;
- }
-};
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/acceptor/LoadShedConfiguration.h>
-
-#include <folly/Conv.h>
-#include <openssl/ssl.h>
-
-using std::string;
-
-namespace folly {
-
-void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) {
- auto addr = input.str();
- size_t separator = addr.find_first_of('/');
- if (separator == string::npos) {
- whitelistAddrs_.insert(SocketAddress(addr, 0));
- } else {
- unsigned prefixLen = folly::to<unsigned>(addr.substr(separator + 1));
- addr.erase(separator);
- whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen));
- }
-}
-
-bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const {
- if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) {
- return true;
- }
- for (auto& network : whitelistNetworks_) {
- if (network.contains(address)) {
- return true;
- }
- }
- return false;
-}
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <chrono>
-#include <folly/Range.h>
-#include <folly/SocketAddress.h>
-#include <glog/logging.h>
-#include <list>
-#include <set>
-#include <string>
-
-#include <folly/experimental/wangle/acceptor/NetworkAddress.h>
-
-namespace folly {
-
-/**
- * Class that holds an LoadShed configuration for a service
- */
-class LoadShedConfiguration {
- public:
-
- // Comparison function for SocketAddress that disregards the port
- struct AddressOnlyCompare {
- bool operator()(
- const SocketAddress& addr1,
- const SocketAddress& addr2) const {
- return addr1.getIPAddress() < addr2.getIPAddress();
- }
- };
-
- typedef std::set<SocketAddress, AddressOnlyCompare> AddressSet;
- typedef std::set<NetworkAddress> NetworkSet;
-
- LoadShedConfiguration() {}
-
- ~LoadShedConfiguration() {}
-
- void addWhitelistAddr(folly::StringPiece);
-
- /**
- * Set/get the set of IPs that should be whitelisted through even when we're
- * trying to shed load.
- */
- void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; }
- const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; }
-
- /**
- * Set/get the set of networks that should be whitelisted through even
- * when we're trying to shed load.
- */
- void setWhitelistNetworks(const NetworkSet& networks) {
- whitelistNetworks_ = networks;
- }
- const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; }
-
- /**
- * Set/get the maximum number of downstream connections across all VIPs.
- */
- void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; }
- uint64_t getMaxConnections() const { return maxConnections_; }
-
- /**
- * Set/get the maximum cpu usage.
- */
- void setMaxMemUsage(double max) {
- CHECK(max >= 0);
- CHECK(max <= 1);
- maxMemUsage_ = max;
- }
- double getMaxMemUsage() const { return maxMemUsage_; }
-
- /**
- * Set/get the maximum memory usage.
- */
- void setMaxCpuUsage(double max) {
- CHECK(max >= 0);
- CHECK(max <= 1);
- maxCpuUsage_ = max;
- }
- double getMaxCpuUsage() const { return maxCpuUsage_; }
-
- void setLoadUpdatePeriod(std::chrono::milliseconds period) {
- period_ = period;
- }
- std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; }
-
- bool isWhitelisted(const SocketAddress& addr) const;
-
- private:
-
- AddressSet whitelistAddrs_;
- NetworkSet whitelistNetworks_;
- uint64_t maxConnections_{0};
- double maxMemUsage_;
- double maxCpuUsage_;
- std::chrono::milliseconds period_;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/SocketAddress.h>
-
-namespace folly {
-
-/**
- * A simple wrapper around SocketAddress that represents
- * a network in CIDR notation
- */
-class NetworkAddress {
-public:
- /**
- * Create a NetworkAddress for an addr/prefixLen
- * @param addr IPv4 or IPv6 address of the network
- * @param prefixLen Prefix length, in bits
- */
- NetworkAddress(const folly::SocketAddress& addr,
- unsigned prefixLen):
- addr_(addr), prefixLen_(prefixLen) {}
-
- /** Get the network address */
- const folly::SocketAddress& getAddress() const {
- return addr_;
- }
-
- /** Get the prefix length in bits */
- unsigned getPrefixLength() const { return prefixLen_; }
-
- /** Check whether a given address lies within the network */
- bool contains(const folly::SocketAddress& addr) const {
- return addr_.prefixMatch(addr, prefixLen_);
- }
-
- /** Comparison operator to enable use in ordered collections */
- bool operator<(const NetworkAddress& other) const {
- if (addr_ < other.addr_) {
- return true;
- } else if (other.addr_ < addr_) {
- return false;
- } else {
- return (prefixLen_ < other.prefixLen_);
- }
- }
-
-private:
- folly::SocketAddress addr_;
- unsigned prefixLen_;
-};
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/experimental/wangle/ssl/SSLCacheOptions.h>
-#include <folly/experimental/wangle/ssl/SSLContextConfig.h>
-#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
-#include <folly/experimental/wangle/ssl/SSLUtil.h>
-#include <folly/experimental/wangle/acceptor/SocketOptions.h>
-
-#include <boost/optional.hpp>
-#include <chrono>
-#include <fcntl.h>
-#include <folly/Random.h>
-#include <folly/SocketAddress.h>
-#include <folly/String.h>
-#include <folly/io/async/SSLContext.h>
-#include <list>
-#include <string>
-#include <sys/stat.h>
-#include <sys/types.h>
-#include <folly/io/async/AsyncSocket.h>
-#include <folly/io/async/SSLContext.h>
-#include <folly/SocketAddress.h>
-
-namespace folly {
-
-/**
- * Configuration for a single Acceptor.
- *
- * This configures not only accept behavior, but also some types of SSL
- * behavior that may make sense to configure on a per-VIP basis (e.g. which
- * cert(s) we use, etc).
- */
-struct ServerSocketConfig {
- ServerSocketConfig() {
- // generate a single random current seed
- uint8_t seed[32];
- folly::Random::secureRandom(seed, sizeof(seed));
- initialTicketSeeds.currentSeeds.push_back(
- SSLUtil::hexlify(std::string((char *)seed, sizeof(seed))));
- }
-
- bool isSSL() const { return !(sslContextConfigs.empty()); }
-
- /**
- * Set/get the socket options to apply on all downstream connections.
- */
- void setSocketOptions(
- const AsyncSocket::OptionMap& opts) {
- socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily());
- }
- AsyncSocket::OptionMap&
- getSocketOptions() {
- return socketOptions_;
- }
- const AsyncSocket::OptionMap&
- getSocketOptions() const {
- return socketOptions_;
- }
-
- bool hasExternalPrivateKey() const {
- for (const auto& cfg : sslContextConfigs) {
- if (!cfg.isLocalPrivateKey) {
- return true;
- }
- }
- return false;
- }
-
- /**
- * The name of this acceptor; used for stats/reporting purposes.
- */
- std::string name;
-
- /**
- * The depth of the accept queue backlog.
- */
- uint32_t acceptBacklog{1024};
-
- /**
- * The number of milliseconds a connection can be idle before we close it.
- */
- std::chrono::milliseconds connectionIdleTimeout{600000};
-
- /**
- * The address to bind to.
- */
- SocketAddress bindAddress;
-
- /**
- * Options for controlling the SSL cache.
- */
- SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200};
-
- /**
- * The initial TLS ticket seeds.
- */
- TLSTicketKeySeeds initialTicketSeeds;
-
- /**
- * The configs for all the SSL_CTX for use by this Acceptor.
- */
- std::vector<SSLContextConfig> sslContextConfigs;
-
- /**
- * Determines if the Acceptor does strict checking when loading the SSL
- * contexts.
- */
- bool strictSSL{true};
-
- /**
- * Maximum number of concurrent pending SSL handshakes
- */
- uint32_t maxConcurrentSSLHandshakes{30720};
-
- private:
- AsyncSocket::OptionMap socketOptions_;
-};
-
-} // folly
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/acceptor/SocketOptions.h>
-
-#include <netinet/tcp.h>
-#include <sys/socket.h>
-
-namespace folly {
-
-AsyncSocket::OptionMap filterIPSocketOptions(
- const AsyncSocket::OptionMap& allOptions,
- const int addrFamily) {
- AsyncSocket::OptionMap opts;
- int exclude;
- if (addrFamily == AF_INET) {
- exclude = IPPROTO_IPV6;
- } else if (addrFamily == AF_INET6) {
- exclude = IPPROTO_IP;
- } else {
- LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6";
- return opts;
- }
- for (const auto& opt: allOptions) {
- if (opt.first.level != exclude) {
- opts[opt.first] = opt.second;
- }
- }
- return opts;
-}
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/io/async/AsyncSocket.h>
-
-namespace folly {
-
-/**
- * Returns a copy of the socket options excluding options with the given
- * level.
- */
-AsyncSocket::OptionMap filterIPSocketOptions(
- const AsyncSocket::OptionMap& allOptions,
- const int addrFamily);
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/acceptor/TransportInfo.h>
-
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <folly/io/async/AsyncSocket.h>
-
-using std::chrono::microseconds;
-using std::map;
-using std::string;
-
-namespace folly {
-
-bool TransportInfo::initWithSocket(const AsyncSocket* sock) {
-#if defined(__linux__) || defined(__FreeBSD__)
- if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
- tcpinfoErrno = errno;
- return false;
- }
- rtt = microseconds(tcpinfo.tcpi_rtt);
- validTcpinfo = true;
-#else
- tcpinfoErrno = EINVAL;
- rtt = microseconds(-1);
-#endif
- return true;
-}
-
-int64_t TransportInfo::readRTT(const AsyncSocket* sock) {
-#if defined(__linux__) || defined(__FreeBSD__)
- struct tcp_info tcpinfo;
- if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
- return -1;
- }
- return tcpinfo.tcpi_rtt;
-#else
- return -1;
-#endif
-}
-
-#if defined(__linux__) || defined(__FreeBSD__)
-bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo,
- const AsyncSocket* sock) {
- socklen_t len = sizeof(struct tcp_info);
- if (!sock) {
- return false;
- }
- if (getsockopt(sock->getFd(), IPPROTO_TCP,
- TCP_INFO, (void*) tcpinfo, &len) < 0) {
- VLOG(4) << "Error calling getsockopt(): " << strerror(errno);
- return false;
- }
- return true;
-}
-#endif
-
-} // folly
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/experimental/wangle/ssl/SSLUtil.h>
-
-#include <chrono>
-#include <netinet/tcp.h>
-#include <string>
-
-namespace folly {
-class AsyncSocket;
-
-/**
- * A structure that encapsulates byte counters related to the HTTP headers.
- */
-struct HTTPHeaderSize {
- /**
- * The number of bytes used to represent the header after compression or
- * before decompression. If header compression is not supported, the value
- * is set to 0.
- */
- size_t compressed{0};
-
- /**
- * The number of bytes used to represent the serialized header before
- * compression or after decompression, in plain-text format.
- */
- size_t uncompressed{0};
-};
-
-struct TransportInfo {
- /*
- * timestamp of when the connection handshake was completed
- */
- std::chrono::steady_clock::time_point acceptTime{};
-
- /*
- * connection RTT (Round-Trip Time)
- */
- std::chrono::microseconds rtt{0};
-
-#if defined(__linux__) || defined(__FreeBSD__)
- /*
- * TCP information as fetched from getsockopt(2)
- */
- tcp_info tcpinfo {
-#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
- 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32
-#else
- 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29
-#endif // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
- };
-#endif // defined(__linux__) || defined(__FreeBSD__)
-
- /*
- * time for setting the connection, from the moment in was accepted until it
- * is established.
- */
- std::chrono::milliseconds setupTime{0};
-
- /*
- * time for setting up the SSL connection or SSL handshake
- */
- std::chrono::milliseconds sslSetupTime{0};
-
- /*
- * The name of the SSL ciphersuite used by the transaction's
- * transport. Returns null if the transport is not SSL.
- */
- const char* sslCipher{nullptr};
-
- /*
- * The SSL server name used by the transaction's
- * transport. Returns null if the transport is not SSL.
- */
- const char* sslServerName{nullptr};
-
- /*
- * list of ciphers sent by the client
- */
- std::string sslClientCiphers{};
-
- /*
- * list of compression methods sent by the client
- */
- std::string sslClientComprMethods{};
-
- /*
- * list of TLS extensions sent by the client
- */
- std::string sslClientExts{};
-
- /*
- * hash of all the SSL parameters sent by the client
- */
- std::string sslSignature{};
-
- /*
- * list of ciphers supported by the server
- */
- std::string sslServerCiphers{};
-
- /*
- * guessed "(os) (browser)" based on SSL Signature
- */
- std::string guessedUserAgent{};
-
- /**
- * The result of SSL NPN negotiation.
- */
- std::string sslNextProtocol{};
-
- /*
- * total number of bytes sent over the connection
- */
- int64_t totalBytes{0};
-
- /**
- * header bytes read
- */
- HTTPHeaderSize ingressHeader;
-
- /*
- * header bytes written
- */
- HTTPHeaderSize egressHeader;
-
- /*
- * Here is how the timeToXXXByte variables are planned out:
- * 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_
- * 2. You can get the timing between two ByteEvents by calculating their
- * differences. For example:
- * timeToLastBodyByteAck - timeToFirstByte
- * => Total time to deliver the body
- * 3. The calculation in point (2) is typically done outside acceptor
- *
- * Future plan:
- * We should log the timestamps (TimePoints) and allow
- * the consumer to calculate the latency whatever it
- * wants instead of calculating them in wangle, for the sake of flexibility.
- * For example:
- * 1. TimePoint reqStartTimestamp;
- * 2. TimePoint firstHeaderByteSentTimestamp;
- * 3. TimePoint firstBodyByteTimestamp;
- * 3. TimePoint lastBodyByteTimestamp;
- * 4. TimePoint lastBodyByteAckTimestamp;
- */
-
- /*
- * time to first header byte written to the kernel send buffer
- * NOTE: It is not 100% accurate since TAsyncSocket does not do
- * do callback on partial write.
- */
- int32_t timeToFirstHeaderByte{-1};
-
- /*
- * time to first body byte written to the kernel send buffer
- */
- int32_t timeToFirstByte{-1};
-
- /*
- * time to last body byte written to the kernel send buffer
- */
- int32_t timeToLastByte{-1};
-
- /*
- * time to TCP Ack received for the last written body byte
- */
- int32_t timeToLastBodyByteAck{-1};
-
- /*
- * time it took the client to ACK the last byte, from the moment when the
- * kernel sent the last byte to the client and until it received the ACK
- * for that byte
- */
- int32_t lastByteAckLatency{-1};
-
- /*
- * time spent inside wangle
- */
- int32_t proxyLatency{-1};
-
- /*
- * time between connection accepted and client message headers completed
- */
- int32_t clientLatency{-1};
-
- /*
- * latency for communication with the server
- */
- int32_t serverLatency{-1};
-
- /*
- * time used to get a usable connection.
- */
- int32_t connectLatency{-1};
-
- /*
- * body bytes written
- */
- uint32_t egressBodySize{0};
-
- /*
- * value of errno in case of getsockopt() error
- */
- int tcpinfoErrno{0};
-
- /*
- * bytes read & written during SSL Setup
- */
- uint32_t sslSetupBytesWritten{0};
- uint32_t sslSetupBytesRead{0};
-
- /**
- * SSL error detail
- */
- uint32_t sslError{0};
-
- /**
- * body bytes read
- */
- uint32_t ingressBodySize{0};
-
- /*
- * The SSL version used by the transaction's transport, in
- * OpenSSL's format: 4 bits for the major version, followed by 4 bits
- * for the minor version. Returns zero for non-SSL.
- */
- uint16_t sslVersion{0};
-
- /*
- * The SSL certificate size.
- */
- uint16_t sslCertSize{0};
-
- /**
- * response status code
- */
- uint16_t statusCode{0};
-
- /*
- * The SSL mode for the transaction's transport: new session,
- * resumed session, or neither (non-SSL).
- */
- SSLResumeEnum sslResume{SSLResumeEnum::NA};
-
- /*
- * true if the tcpinfo was successfully read from the kernel
- */
- bool validTcpinfo{false};
-
- /*
- * true if the connection is SSL, false otherwise
- */
- bool ssl{false};
-
- /*
- * get the RTT value in milliseconds
- */
- std::chrono::milliseconds getRttMs() const {
- return std::chrono::duration_cast<std::chrono::milliseconds>(rtt);
- }
-
- /*
- * initialize the fields related with tcp_info
- */
- bool initWithSocket(const AsyncSocket* sock);
-
- /*
- * Get the kernel's estimate of round-trip time (RTT) to the transport's peer
- * in microseconds. Returns -1 on error.
- */
- static int64_t readRTT(const AsyncSocket* sock);
-
-#if defined(__linux__) || defined(__FreeBSD__)
- /*
- * perform the getsockopt(2) syscall to fetch TCP info for a given socket
- */
- static bool readTcpInfo(struct tcp_info* tcpinfo,
- const AsyncSocket* sock);
-#endif
-};
-
-} // folly
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "folly/experimental/wangle/bootstrap/ServerBootstrap.h"
-#include "folly/experimental/wangle/bootstrap/ClientBootstrap.h"
-#include "folly/experimental/wangle/channel/ChannelHandler.h"
-
-#include <glog/logging.h>
-#include <gtest/gtest.h>
-
-using namespace folly::wangle;
-using namespace folly;
-
-typedef ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>> Pipeline;
-
-class TestServer : public ServerBootstrap<Pipeline> {
- Pipeline* newPipeline(std::shared_ptr<AsyncSocket>) {
- return nullptr;
- }
-};
-
-class TestClient : public ClientBootstrap<Pipeline> {
- Pipeline* newPipeline(std::shared_ptr<AsyncSocket> sock) {
- CHECK(sock->good());
-
- // We probably aren't connected immedately, check after a small delay
- EventBaseManager::get()->getEventBase()->runAfterDelay([sock](){
- CHECK(sock->readable());
- }, 100);
- return nullptr;
- }
-};
-
-class TestPipelineFactory : public PipelineFactory<Pipeline> {
- public:
- Pipeline* newPipeline(std::shared_ptr<AsyncSocket> sock) {
- pipelines++;
- return new Pipeline();
- }
- std::atomic<int> pipelines{0};
-};
-
-TEST(Bootstrap, Basic) {
- TestServer server;
- TestClient client;
-}
-
-TEST(Bootstrap, ServerWithPipeline) {
- TestServer server;
- server.childPipeline(std::make_shared<TestPipelineFactory>());
- server.bind(0);
- server.stop();
-}
-
-TEST(Bootstrap, ClientServerTest) {
- TestServer server;
- auto factory = std::make_shared<TestPipelineFactory>();
- server.childPipeline(factory);
- server.bind(0);
- auto base = EventBaseManager::get()->getEventBase();
-
- SocketAddress address;
- server.getSockets()[0]->getAddress(&address);
-
- TestClient client;
- client.connect(address);
- base->loop();
- server.stop();
-
- CHECK(factory->pipelines == 1);
-}
-
-TEST(Bootstrap, ClientConnectionManagerTest) {
- // Create a single IO thread, and verify that
- // client connections are pooled properly
-
- TestServer server;
- auto factory = std::make_shared<TestPipelineFactory>();
- server.childPipeline(factory);
- server.group(std::make_shared<IOThreadPoolExecutor>(1));
- server.bind(0);
- auto base = EventBaseManager::get()->getEventBase();
-
- SocketAddress address;
- server.getSockets()[0]->getAddress(&address);
-
- TestClient client;
- client.connect(address);
-
- TestClient client2;
- client2.connect(address);
-
- base->loop();
- server.stop();
-
- CHECK(factory->pipelines == 2);
-}
-
-TEST(Bootstrap, ServerAcceptGroupTest) {
- // Verify that server is using the accept IO group
-
- TestServer server;
- auto factory = std::make_shared<TestPipelineFactory>();
- server.childPipeline(factory);
- server.group(std::make_shared<IOThreadPoolExecutor>(1), nullptr);
- server.bind(0);
-
- SocketAddress address;
- server.getSockets()[0]->getAddress(&address);
-
- boost::barrier barrier(2);
- auto thread = std::thread([&](){
- TestClient client;
- client.connect(address);
- EventBaseManager::get()->getEventBase()->loop();
- barrier.wait();
- });
- barrier.wait();
- server.stop();
- thread.join();
-
- CHECK(factory->pipelines == 1);
-}
-
-TEST(Bootstrap, ServerAcceptGroup2Test) {
- // Verify that server is using the accept IO group
-
- // Check if reuse port is supported, if not, don't run this test
- try {
- EventBase base;
- auto serverSocket = AsyncServerSocket::newSocket(&base);
- serverSocket->bind(0);
- serverSocket->listen(0);
- serverSocket->startAccepting();
- serverSocket->setReusePortEnabled(true);
- serverSocket->stopAccepting();
- } catch(...) {
- LOG(INFO) << "Reuse port probably not supported";
- return;
- }
-
- TestServer server;
- auto factory = std::make_shared<TestPipelineFactory>();
- server.childPipeline(factory);
- server.group(std::make_shared<IOThreadPoolExecutor>(4), nullptr);
- server.bind(0);
-
- SocketAddress address;
- server.getSockets()[0]->getAddress(&address);
-
- TestClient client;
- client.connect(address);
- EventBaseManager::get()->getEventBase()->loop();
-
- server.stop();
-
- CHECK(factory->pipelines == 1);
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include <folly/experimental/wangle/channel/ChannelPipeline.h>
-
-namespace folly {
-
-/*
- * A thin wrapper around ChannelPipeline and AsyncSocket to match
- * ServerBootstrap. On connect() a new pipeline is created.
- */
-template <typename Pipeline>
-class ClientBootstrap {
- public:
- ClientBootstrap() {
- }
- ClientBootstrap* bind(int port) {
- port_ = port;
- return this;
- }
- ClientBootstrap* connect(SocketAddress address) {
- pipeline_.reset(
- newPipeline(
- AsyncSocket::newSocket(EventBaseManager::get()->getEventBase(), address)
- ));
- return this;
- }
-
- virtual ~ClientBootstrap() {}
-
- protected:
- std::unique_ptr<Pipeline,
- folly::DelayedDestruction::Destructor> pipeline_;
-
- int port_;
-
- virtual Pipeline* newPipeline(std::shared_ptr<AsyncSocket> socket) = 0;
-};
-
-} // namespace
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include <folly/experimental/wangle/acceptor/Acceptor.h>
-#include <folly/io/async/EventBaseManager.h>
-#include <folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h>
-#include <folly/experimental/wangle/ManagedConnection.h>
-#include <folly/experimental/wangle/channel/ChannelPipeline.h>
-
-namespace folly {
-
-template <typename Pipeline>
-class ServerAcceptor : public Acceptor {
- typedef std::unique_ptr<Pipeline,
- folly::DelayedDestruction::Destructor> PipelinePtr;
-
- class ServerConnection : public wangle::ManagedConnection {
- public:
- explicit ServerConnection(PipelinePtr pipeline)
- : pipeline_(std::move(pipeline)) {}
-
- ~ServerConnection() {
- }
-
- void timeoutExpired() noexcept {
- }
-
- void describe(std::ostream& os) const {}
- bool isBusy() const {
- return false;
- }
- void notifyPendingShutdown() {}
- void closeWhenIdle() {}
- void dropConnection() {}
- void dumpConnectionState(uint8_t loglevel) {}
- private:
- PipelinePtr pipeline_;
- };
-
- public:
- explicit ServerAcceptor(
- std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory)
- : Acceptor(ServerSocketConfig())
- , pipelineFactory_(pipelineFactory) {
- Acceptor::init(nullptr, &base_);
- }
-
- /* See Acceptor::onNewConnection for details */
- void onNewConnection(
- AsyncSocket::UniquePtr transport, const SocketAddress* address,
- const std::string& nextProtocolName, const TransportInfo& tinfo) {
-
- std::unique_ptr<Pipeline,
- folly::DelayedDestruction::Destructor>
- pipeline(pipelineFactory_->newPipeline(
- std::shared_ptr<AsyncSocket>(
- transport.release(),
- folly::DelayedDestruction::Destructor())));
- auto connection = new ServerConnection(std::move(pipeline));
- Acceptor::addConnection(connection);
- }
-
- ~ServerAcceptor() {
- Acceptor::dropAllConnections();
- }
-
- private:
- EventBase base_;
-
- std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_;
-};
-
-template <typename Pipeline>
-class ServerAcceptorFactory : public AcceptorFactory {
- public:
- explicit ServerAcceptorFactory(
- std::shared_ptr<PipelineFactory<Pipeline>> factory)
- : factory_(factory) {}
-
- std::shared_ptr<Acceptor> newAcceptor() {
- return std::make_shared<ServerAcceptor<Pipeline>>(factory_);
- }
- private:
- std::shared_ptr<PipelineFactory<Pipeline>> factory_;
-};
-
-class ServerWorkerFactory : public folly::wangle::ThreadFactory {
- public:
- explicit ServerWorkerFactory(std::shared_ptr<AcceptorFactory> acceptorFactory)
- : internalFactory_(
- std::make_shared<folly::wangle::NamedThreadFactory>("BootstrapWorker"))
- , acceptorFactory_(acceptorFactory)
- {}
- virtual std::thread newThread(folly::Func&& func) override;
-
- void setInternalFactory(
- std::shared_ptr<folly::wangle::NamedThreadFactory> internalFactory);
- void setNamePrefix(folly::StringPiece prefix);
-
- template <typename F>
- void forEachWorker(F&& f);
-
- private:
- std::shared_ptr<folly::wangle::NamedThreadFactory> internalFactory_;
- folly::RWSpinLock workersLock_;
- std::map<int32_t, std::shared_ptr<Acceptor>> workers_;
- int32_t nextWorkerId_{0};
-
- std::shared_ptr<AcceptorFactory> acceptorFactory_;
-};
-
-template <typename F>
-void ServerWorkerFactory::forEachWorker(F&& f) {
- folly::RWSpinLock::ReadHolder guard(workersLock_);
- for (const auto& kv : workers_) {
- f(kv.second.get());
- }
-}
-
-} // namespace
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <folly/experimental/wangle/bootstrap/ServerBootstrap.h>
-#include <folly/experimental/wangle/concurrent/NamedThreadFactory.h>
-#include <folly/io/async/EventBaseManager.h>
-
-namespace folly {
-
-std::thread ServerWorkerFactory::newThread(
- folly::Func&& func) {
- auto id = nextWorkerId_++;
- auto worker = acceptorFactory_->newAcceptor();
- {
- folly::RWSpinLock::WriteHolder guard(workersLock_);
- workers_.insert({id, worker});
- }
- return internalFactory_->newThread([=](){
- EventBaseManager::get()->setEventBase(worker->getEventBase(), false);
- func();
- EventBaseManager::get()->clearEventBase();
-
- worker->drainAllConnections();
- {
- folly::RWSpinLock::WriteHolder guard(workersLock_);
- workers_.erase(id);
- }
- });
-}
-
-void ServerWorkerFactory::setInternalFactory(
- std::shared_ptr<wangle::NamedThreadFactory> internalFactory) {
- CHECK(workers_.empty());
- internalFactory_ = internalFactory;
-}
-
-void ServerWorkerFactory::setNamePrefix(folly::StringPiece prefix) {
- CHECK(workers_.empty());
- internalFactory_->setNamePrefix(prefix);
-}
-
-} // namespace
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include <folly/experimental/wangle/bootstrap/ServerBootstrap-inl.h>
-#include <boost/thread.hpp>
-
-namespace folly {
-
-/*
- * ServerBootstrap is a parent class intended to set up a
- * high-performance TCP accepting server. It will manage a pool of
- * accepting threads, any number of accepting sockets, a pool of
- * IO-worker threads, and connection pool for each IO thread for you.
- *
- * The output is given as a ChannelPipeline template: given a
- * PipelineFactory, it will create a new pipeline for each connection,
- * and your server can handle the incoming bytes.
- *
- * BACKWARDS COMPATIBLITY: for servers already taking a pool of
- * Acceptor objects, an AcceptorFactory can be given directly instead
- * of a pipeline factory.
- */
-template <typename Pipeline>
-class ServerBootstrap {
- public:
- /* TODO(davejwatson)
- *
- * If there is any work to be done BEFORE handing the work to IO
- * threads, this handler is where the pipeline to do it would be
- * set.
- *
- * This could be used for things like logging, load balancing, or
- * advanced load balancing on IO threads. Netty also provides this.
- */
- ServerBootstrap* handler() {
- return this;
- }
-
- /*
- * BACKWARDS COMPATIBILITY - an acceptor factory can be set. Your
- * Acceptor is responsible for managing the connection pool.
- *
- * @param childHandler - acceptor factory to call for each IO thread
- */
- ServerBootstrap* childHandler(std::shared_ptr<AcceptorFactory> childHandler) {
- acceptorFactory_ = childHandler;
- return this;
- }
-
- /*
- * Set a pipeline factory that will be called for each new connection
- *
- * @param factory pipeline factory to use for each new connection
- */
- ServerBootstrap* childPipeline(
- std::shared_ptr<PipelineFactory<Pipeline>> factory) {
- pipelineFactory_ = factory;
- return this;
- }
-
- /*
- * Set the IO executor. If not set, a default one will be created
- * with one thread per core.
- *
- * @param io_group - io executor to use for IO threads.
- */
- ServerBootstrap* group(
- std::shared_ptr<folly::wangle::IOThreadPoolExecutor> io_group) {
- return group(nullptr, io_group);
- }
-
- /*
- * Set the acceptor executor, and IO executor.
- *
- * If no acceptor executor is set, a single thread will be created for accepts
- * If no IO executor is set, a default of one thread per core will be created
- *
- * @param group - acceptor executor to use for acceptor threads.
- * @param io_group - io executor to use for IO threads.
- */
- ServerBootstrap* group(
- std::shared_ptr<folly::wangle::IOThreadPoolExecutor> accept_group,
- std::shared_ptr<wangle::IOThreadPoolExecutor> io_group) {
- if (!accept_group) {
- accept_group = std::make_shared<folly::wangle::IOThreadPoolExecutor>(
- 1, std::make_shared<wangle::NamedThreadFactory>("Acceptor Thread"));
- }
- if (!io_group) {
- io_group = std::make_shared<folly::wangle::IOThreadPoolExecutor>(
- 32, std::make_shared<wangle::NamedThreadFactory>("IO Thread"));
- }
- auto factoryBase = io_group->getThreadFactory();
- CHECK(factoryBase);
- auto factory = std::dynamic_pointer_cast<folly::wangle::NamedThreadFactory>(
- factoryBase);
- CHECK(factory); // Must be named thread factory
-
- CHECK(acceptorFactory_ || pipelineFactory_);
-
- if (acceptorFactory_) {
- workerFactory_ = std::make_shared<ServerWorkerFactory>(
- acceptorFactory_);
- } else {
- workerFactory_ = std::make_shared<ServerWorkerFactory>(
- std::make_shared<ServerAcceptorFactory<Pipeline>>(pipelineFactory_));
- }
- workerFactory_->setInternalFactory(factory);
-
- acceptor_group_ = accept_group;
- io_group_ = io_group;
-
- auto numThreads = io_group_->numThreads();
- io_group_->setNumThreads(0);
- io_group_->setThreadFactory(workerFactory_);
- io_group_->setNumThreads(numThreads);
-
- return this;
- }
-
- /*
- * Bind to a port and start listening.
- * One of childPipeline or childHandler must be called before bind
- *
- * @param port Port to listen on
- */
- void bind(int port) {
- // TODO take existing socket
-
- if (!workerFactory_) {
- group(nullptr);
- }
-
- bool reusePort = false;
- if (acceptor_group_->numThreads() > 1) {
- reusePort = true;
- }
-
- std::mutex sock_lock;
- std::vector<std::shared_ptr<folly::AsyncServerSocket>> new_sockets;
-
- auto startupFunc = [&](std::shared_ptr<boost::barrier> barrier){
- auto socket = folly::AsyncServerSocket::newSocket();
- sock_lock.lock();
- new_sockets.push_back(socket);
- sock_lock.unlock();
- socket->setReusePortEnabled(reusePort);
- socket->attachEventBase(EventBaseManager::get()->getEventBase());
- socket->bind(port);
- // TODO Take ServerSocketConfig
- socket->listen(1024);
- socket->startAccepting();
-
- if (port == 0) {
- SocketAddress address;
- socket->getAddress(&address);
- port = address.getPort();
- }
-
- barrier->wait();
- };
-
- auto bind0 = std::make_shared<boost::barrier>(2);
- acceptor_group_->add(std::bind(startupFunc, bind0));
- bind0->wait();
-
- auto barrier = std::make_shared<boost::barrier>(acceptor_group_->numThreads());
- for (int i = 1; i < acceptor_group_->numThreads(); i++) {
- acceptor_group_->add(std::bind(startupFunc, barrier));
- }
- barrier->wait();
-
- // Startup all the threads
- for(auto socket : new_sockets) {
- workerFactory_->forEachWorker([this, socket](Acceptor* worker){
- socket->getEventBase()->runInEventBaseThread([this, worker, socket](){
- socket->addAcceptCallback(worker, worker->getEventBase());
- });
- });
- }
-
- for (auto& socket : new_sockets) {
- sockets_.push_back(socket);
- }
- }
-
- /*
- * Stop listening on all sockets.
- */
- void stop() {
- auto barrier = std::make_shared<boost::barrier>(sockets_.size() + 1);
- for (auto socket : sockets_) {
- socket->getEventBase()->runInEventBaseThread([barrier, socket]() {
- socket->stopAccepting();
- socket->detachEventBase();
- barrier->wait();
- });
- }
- barrier->wait();
- sockets_.clear();
-
- acceptor_group_->join();
- io_group_->join();
- }
-
- /*
- * Get the list of listening sockets
- */
- std::vector<std::shared_ptr<folly::AsyncServerSocket>>&
- getSockets() {
- return sockets_;
- }
-
- private:
- std::shared_ptr<wangle::IOThreadPoolExecutor> acceptor_group_;
- std::shared_ptr<wangle::IOThreadPoolExecutor> io_group_;
-
- std::shared_ptr<ServerWorkerFactory> workerFactory_;
- std::vector<std::shared_ptr<folly::AsyncServerSocket>> sockets_;
-
- std::shared_ptr<AcceptorFactory> acceptorFactory_;
- std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_;
-};
-
-} // namespace
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/channel/ChannelHandler.h>
-#include <folly/io/async/AsyncSocket.h>
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/EventBaseManager.h>
-#include <folly/io/IOBuf.h>
-#include <folly/io/IOBufQueue.h>
-
-namespace folly { namespace wangle {
-
-class AsyncSocketHandler
- : public folly::wangle::BytesToBytesHandler,
- public AsyncSocket::ReadCallback {
- public:
- explicit AsyncSocketHandler(
- std::shared_ptr<AsyncSocket> socket)
- : socket_(std::move(socket)) {}
-
- AsyncSocketHandler(AsyncSocketHandler&&) = default;
-
- ~AsyncSocketHandler() {
- if (socket_) {
- detachReadCallback();
- }
- }
-
- void attachReadCallback() {
- socket_->setReadCB(socket_->good() ? this : nullptr);
- }
-
- void detachReadCallback() {
- if (socket_->getReadCallback() == this) {
- socket_->setReadCB(nullptr);
- }
- }
-
- void attachEventBase(folly::EventBase* eventBase) {
- if (eventBase && !socket_->getEventBase()) {
- socket_->attachEventBase(eventBase);
- }
- }
-
- void detachEventBase() {
- detachReadCallback();
- if (socket_->getEventBase()) {
- socket_->detachEventBase();
- }
- }
-
- void attachPipeline(Context* ctx) override {
- CHECK(!ctx_);
- ctx_ = ctx;
- }
-
- folly::wangle::Future<void> write(
- Context* ctx,
- std::unique_ptr<folly::IOBuf> buf) override {
- if (UNLIKELY(!buf)) {
- return folly::wangle::makeFuture();
- }
-
- if (!socket_->good()) {
- VLOG(5) << "socket is closed in write()";
- return folly::wangle::makeFuture<void>(AsyncSocketException(
- AsyncSocketException::AsyncSocketExceptionType::NOT_OPEN,
- "socket is closed in write()"));
- }
-
- auto cb = new WriteCallback();
- auto future = cb->promise_.getFuture();
- socket_->writeChain(cb, std::move(buf), ctx->getWriteFlags());
- return future;
- };
-
- folly::wangle::Future<void> close(Context* ctx) {
- if (socket_) {
- detachReadCallback();
- socket_->closeNow();
- }
- return folly::wangle::makeFuture();
- }
-
- // Must override to avoid warnings about hidden overloaded virtual due to
- // AsyncSocket::ReadCallback::readEOF()
- void readEOF(Context* ctx) override {
- ctx->fireReadEOF();
- }
-
- void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
- const auto readBufferSettings = ctx_->getReadBufferSettings();
- const auto ret = bufQueue_.preallocate(
- readBufferSettings.first,
- readBufferSettings.second);
- *bufReturn = ret.first;
- *lenReturn = ret.second;
- }
-
- void readDataAvailable(size_t len) noexcept override {
- bufQueue_.postallocate(len);
- ctx_->fireRead(bufQueue_);
- }
-
- void readEOF() noexcept override {
- ctx_->fireReadEOF();
- }
-
- void readErr(const AsyncSocketException& ex)
- noexcept override {
- ctx_->fireReadException(make_exception_wrapper<AsyncSocketException>(ex));
- }
-
- private:
- class WriteCallback : private AsyncSocket::WriteCallback {
- void writeSuccess() noexcept override {
- promise_.setValue();
- delete this;
- }
-
- void writeErr(size_t bytesWritten,
- const AsyncSocketException& ex)
- noexcept override {
- promise_.setException(ex);
- delete this;
- }
-
- private:
- friend class AsyncSocketHandler;
- folly::wangle::Promise<void> promise_;
- };
-
- Context* ctx_{nullptr};
- folly::IOBufQueue bufQueue_;
- std::shared_ptr<AsyncSocket> socket_{nullptr};
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/wangle/futures/Future.h>
-#include <folly/experimental/wangle/channel/ChannelPipeline.h>
-#include <folly/io/IOBuf.h>
-#include <folly/io/IOBufQueue.h>
-
-namespace folly { namespace wangle {
-
-template <class Rin, class Rout = Rin, class Win = Rout, class Wout = Rin>
-class ChannelHandler {
- public:
- typedef Rin rin;
- typedef Rout rout;
- typedef Win win;
- typedef Wout wout;
- typedef ChannelHandlerContext<Rout, Wout> Context;
- virtual ~ChannelHandler() {}
-
- virtual void read(Context* ctx, Rin msg) = 0;
- virtual void readEOF(Context* ctx) {
- ctx->fireReadEOF();
- }
- virtual void readException(Context* ctx, exception_wrapper e) {
- ctx->fireReadException(std::move(e));
- }
-
- virtual Future<void> write(Context* ctx, Win msg) = 0;
- virtual Future<void> close(Context* ctx) {
- return ctx->fireClose();
- }
-
- virtual void attachPipeline(Context* ctx) {}
- virtual void attachTransport(Context* ctx) {}
-
- virtual void detachPipeline(Context* ctx) {}
- virtual void detachTransport(Context* ctx) {}
-
- /*
- // Other sorts of things we might want, all shamelessly stolen from Netty
- // inbound
- virtual void exceptionCaught(
- ChannelHandlerContext* ctx,
- exception_wrapper e) {}
- virtual void channelRegistered(ChannelHandlerContext* ctx) {}
- virtual void channelUnregistered(ChannelHandlerContext* ctx) {}
- virtual void channelActive(ChannelHandlerContext* ctx) {}
- virtual void channelInactive(ChannelHandlerContext* ctx) {}
- virtual void channelReadComplete(ChannelHandlerContext* ctx) {}
- virtual void userEventTriggered(ChannelHandlerContext* ctx, void* evt) {}
- virtual void channelWritabilityChanged(ChannelHandlerContext* ctx) {}
-
- // outbound
- virtual Future<void> bind(
- ChannelHandlerContext* ctx,
- SocketAddress localAddress) {}
- virtual Future<void> connect(
- ChannelHandlerContext* ctx,
- SocketAddress remoteAddress, SocketAddress localAddress) {}
- virtual Future<void> disconnect(ChannelHandlerContext* ctx) {}
- virtual Future<void> deregister(ChannelHandlerContext* ctx) {}
- virtual Future<void> read(ChannelHandlerContext* ctx) {}
- virtual void flush(ChannelHandlerContext* ctx) {}
- */
-};
-
-template <class R, class W = R>
-class ChannelHandlerAdapter : public ChannelHandler<R, R, W, W> {
- public:
- typedef typename ChannelHandler<R, R, W, W>::Context Context;
-
- void read(Context* ctx, R msg) override {
- ctx->fireRead(std::forward<R>(msg));
- }
-
- Future<void> write(Context* ctx, W msg) override {
- return ctx->fireWrite(std::forward<W>(msg));
- }
-};
-
-typedef ChannelHandlerAdapter<IOBufQueue&, std::unique_ptr<IOBuf>>
-BytesToBytesHandler;
-
-template <class Handler, bool Shared = true>
-class ChannelHandlerPtr : public ChannelHandler<
- typename Handler::rin,
- typename Handler::rout,
- typename Handler::win,
- typename Handler::wout> {
- public:
- typedef typename std::conditional<
- Shared,
- std::shared_ptr<Handler>,
- Handler*>::type
- HandlerPtr;
-
- typedef typename Handler::Context Context;
-
- explicit ChannelHandlerPtr(HandlerPtr handler)
- : handler_(std::move(handler)) {}
-
- void setHandler(HandlerPtr handler) {
- if (handler == handler_) {
- return;
- }
- if (handler_ && ctx_) {
- handler_->detachPipeline(ctx_);
- }
- handler_ = std::move(handler);
- if (handler_ && ctx_) {
- handler_->attachPipeline(ctx_);
- if (ctx_->getTransport()) {
- handler_->attachTransport(ctx_);
- }
- }
- }
-
- void attachPipeline(Context* ctx) override {
- ctx_ = ctx;
- if (handler_) {
- handler_->attachPipeline(ctx_);
- }
- }
-
- void attachTransport(Context* ctx) override {
- ctx_ = ctx;
- if (handler_) {
- handler_->attachTransport(ctx_);
- }
- }
-
- void detachPipeline(Context* ctx) override {
- ctx_ = ctx;
- if (handler_) {
- handler_->detachPipeline(ctx_);
- }
- }
-
- void detachTransport(Context* ctx) override {
- ctx_ = ctx;
- if (handler_) {
- handler_->detachTransport(ctx_);
- }
- }
-
- void read(Context* ctx, typename Handler::rin msg) override {
- DCHECK(handler_);
- handler_->read(ctx, std::forward<typename Handler::rin>(msg));
- }
-
- void readEOF(Context* ctx) override {
- DCHECK(handler_);
- handler_->readEOF(ctx);
- }
-
- void readException(Context* ctx, exception_wrapper e) override {
- DCHECK(handler_);
- handler_->readException(ctx, std::move(e));
- }
-
- Future<void> write(Context* ctx, typename Handler::win msg) override {
- DCHECK(handler_);
- return handler_->write(ctx, std::forward<typename Handler::win>(msg));
- }
-
- Future<void> close(Context* ctx) override {
- DCHECK(handler_);
- return handler_->close(ctx);
- }
-
- private:
- Context* ctx_;
- HandlerPtr handler_;
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/io/async/AsyncTransport.h>
-#include <folly/wangle/futures/Future.h>
-#include <folly/ExceptionWrapper.h>
-
-namespace folly { namespace wangle {
-
-template <class In, class Out>
-class ChannelHandlerContext {
- public:
- virtual ~ChannelHandlerContext() {}
-
- virtual void fireRead(In msg) = 0;
- virtual void fireReadEOF() = 0;
- virtual void fireReadException(exception_wrapper e) = 0;
-
- virtual Future<void> fireWrite(Out msg) = 0;
- virtual Future<void> fireClose() = 0;
-
- virtual std::shared_ptr<AsyncTransport> getTransport() = 0;
-
- virtual void setWriteFlags(WriteFlags flags) = 0;
- virtual WriteFlags getWriteFlags() = 0;
-
- virtual void setReadBufferSettings(
- uint64_t minAvailable,
- uint64_t allocationSize) = 0;
- virtual std::pair<uint64_t, uint64_t> getReadBufferSettings() = 0;
-
- /* TODO
- template <class H>
- virtual void addHandlerBefore(H&&) {}
- template <class H>
- virtual void addHandlerAfter(H&&) {}
- template <class H>
- virtual void replaceHandler(H&&) {}
- virtual void removeHandler() {}
- */
-};
-
-class PipelineContext {
- public:
- virtual ~PipelineContext() {}
-
- virtual void attachTransport() = 0;
- virtual void detachTransport() = 0;
-
- void link(PipelineContext* other) {
- setNextIn(other);
- other->setNextOut(this);
- }
-
- protected:
- virtual void setNextIn(PipelineContext* ctx) = 0;
- virtual void setNextOut(PipelineContext* ctx) = 0;
-};
-
-template <class In>
-class InboundChannelHandlerContext {
- public:
- virtual ~InboundChannelHandlerContext() {}
- virtual void read(In msg) = 0;
- virtual void readEOF() = 0;
- virtual void readException(exception_wrapper e) = 0;
-};
-
-template <class Out>
-class OutboundChannelHandlerContext {
- public:
- virtual ~OutboundChannelHandlerContext() {}
- virtual Future<void> write(Out msg) = 0;
- virtual Future<void> close() = 0;
-};
-
-template <class P, class H>
-class ContextImpl : public ChannelHandlerContext<typename H::rout,
- typename H::wout>,
- public InboundChannelHandlerContext<typename H::rin>,
- public OutboundChannelHandlerContext<typename H::win>,
- public PipelineContext {
- public:
- typedef typename H::rin Rin;
- typedef typename H::rout Rout;
- typedef typename H::win Win;
- typedef typename H::wout Wout;
-
- template <class HandlerArg>
- explicit ContextImpl(P* pipeline, HandlerArg&& handlerArg)
- : pipeline_(pipeline),
- handler_(std::forward<HandlerArg>(handlerArg)) {
- handler_.attachPipeline(this);
- }
-
- ~ContextImpl() {
- handler_.detachPipeline(this);
- }
-
- H* getHandler() {
- return &handler_;
- }
-
- // PipelineContext overrides
- void setNextIn(PipelineContext* ctx) override {
- auto nextIn = dynamic_cast<InboundChannelHandlerContext<Rout>*>(ctx);
- if (nextIn) {
- nextIn_ = nextIn;
- } else {
- throw std::invalid_argument("wrong type in setNextIn");
- }
- }
-
- void setNextOut(PipelineContext* ctx) override {
- auto nextOut = dynamic_cast<OutboundChannelHandlerContext<Wout>*>(ctx);
- if (nextOut) {
- nextOut_ = nextOut;
- } else {
- throw std::invalid_argument("wrong type in setNextOut");
- }
- }
-
- void attachTransport() override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- handler_.attachTransport(this);
- }
-
- void detachTransport() override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- handler_.detachTransport(this);
- }
-
- // ChannelHandlerContext overrides
- void fireRead(Rout msg) override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- if (nextIn_) {
- nextIn_->read(std::forward<Rout>(msg));
- } else {
- LOG(WARNING) << "read reached end of pipeline";
- }
- }
-
- void fireReadEOF() override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- if (nextIn_) {
- nextIn_->readEOF();
- } else {
- LOG(WARNING) << "readEOF reached end of pipeline";
- }
- }
-
- void fireReadException(exception_wrapper e) override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- if (nextIn_) {
- nextIn_->readException(std::move(e));
- } else {
- LOG(WARNING) << "readException reached end of pipeline";
- }
- }
-
- Future<void> fireWrite(Wout msg) override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- if (nextOut_) {
- return nextOut_->write(std::forward<Wout>(msg));
- } else {
- LOG(WARNING) << "write reached end of pipeline";
- return makeFuture();
- }
- }
-
- Future<void> fireClose() override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- if (nextOut_) {
- return nextOut_->close();
- } else {
- LOG(WARNING) << "close reached end of pipeline";
- return makeFuture();
- }
- }
-
- std::shared_ptr<AsyncTransport> getTransport() override {
- return pipeline_->getTransport();
- }
-
- void setWriteFlags(WriteFlags flags) override {
- pipeline_->setWriteFlags(flags);
- }
-
- WriteFlags getWriteFlags() override {
- return pipeline_->getWriteFlags();
- }
-
- void setReadBufferSettings(
- uint64_t minAvailable,
- uint64_t allocationSize) override {
- pipeline_->setReadBufferSettings(minAvailable, allocationSize);
- }
-
- std::pair<uint64_t, uint64_t> getReadBufferSettings() override {
- return pipeline_->getReadBufferSettings();
- }
-
- // InboundChannelHandlerContext overrides
- void read(Rin msg) override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- handler_.read(this, std::forward<Rin>(msg));
- }
-
- void readEOF() override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- handler_.readEOF(this);
- }
-
- void readException(exception_wrapper e) override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- handler_.readException(this, std::move(e));
- }
-
- // OutboundChannelHandlerContext overrides
- Future<void> write(Win msg) override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- return handler_.write(this, std::forward<Win>(msg));
- }
-
- Future<void> close() override {
- typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
- return handler_.close(this);
- }
-
- private:
- P* pipeline_;
- H handler_;
- InboundChannelHandlerContext<Rout>* nextIn_{nullptr};
- OutboundChannelHandlerContext<Wout>* nextOut_{nullptr};
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/channel/ChannelHandlerContext.h>
-#include <folly/wangle/futures/Future.h>
-#include <folly/io/async/AsyncTransport.h>
-#include <folly/io/async/DelayedDestruction.h>
-#include <folly/ExceptionWrapper.h>
-#include <folly/Memory.h>
-#include <glog/logging.h>
-
-namespace folly { namespace wangle {
-
-/*
- * R is the inbound type, i.e. inbound calls start with pipeline.read(R)
- * W is the outbound type, i.e. outbound calls start with pipeline.write(W)
- */
-template <class R, class W, class... Handlers>
-class ChannelPipeline;
-
-template <class R, class W>
-class ChannelPipeline<R, W> : public DelayedDestruction {
- public:
- ChannelPipeline() {}
- ~ChannelPipeline() {}
-
- std::shared_ptr<AsyncTransport> getTransport() {
- return transport_;
- }
-
- void setWriteFlags(WriteFlags flags) {
- writeFlags_ = flags;
- }
-
- WriteFlags getWriteFlags() {
- return writeFlags_;
- }
-
- void setReadBufferSettings(uint64_t minAvailable, uint64_t allocationSize) {
- readBufferSettings_ = std::make_pair(minAvailable, allocationSize);
- }
-
- std::pair<uint64_t, uint64_t> getReadBufferSettings() {
- return readBufferSettings_;
- }
-
- void read(R msg) {
- front_->read(std::forward<R>(msg));
- }
-
- void readEOF() {
- front_->readEOF();
- }
-
- void readException(exception_wrapper e) {
- front_->readException(std::move(e));
- }
-
- Future<void> write(W msg) {
- return back_->write(std::forward<W>(msg));
- }
-
- Future<void> close() {
- return back_->close();
- }
-
- template <class H>
- ChannelPipeline& addBack(H&& handler) {
- ctxs_.push_back(folly::make_unique<ContextImpl<ChannelPipeline, H>>(
- this, std::forward<H>(handler)));
- return *this;
- }
-
- template <class H>
- ChannelPipeline& addFront(H&& handler) {
- ctxs_.insert(
- ctxs_.begin(),
- folly::make_unique<ContextImpl<ChannelPipeline, H>>(
- this,
- std::forward<H>(handler)));
- return *this;
- }
-
- template <class H>
- H* getHandler(int i) {
- auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(ctxs_[i].get());
- CHECK(ctx);
- return ctx->getHandler();
- }
-
- void finalize() {
- finalizeHelper();
- InboundChannelHandlerContext<R>* front;
- front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(
- ctxs_.front().get());
- if (!front_) {
- throw std::invalid_argument("wrong type for first handler");
- }
- }
-
- protected:
- explicit ChannelPipeline(bool shouldFinalize) {
- CHECK(!shouldFinalize);
- }
-
- void finalizeHelper() {
- if (ctxs_.empty()) {
- return;
- }
-
- for (int i = 0; i < ctxs_.size() - 1; i++) {
- ctxs_[i]->link(ctxs_[i+1].get());
- }
-
- back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(ctxs_.back().get());
- if (!back_) {
- throw std::invalid_argument("wrong type for last handler");
- }
- }
-
- PipelineContext* getLocalFront() {
- return ctxs_.empty() ? nullptr : ctxs_.front().get();
- }
-
- static const bool is_end{true};
-
- std::shared_ptr<AsyncTransport> transport_;
- WriteFlags writeFlags_{WriteFlags::NONE};
- std::pair<uint64_t, uint64_t> readBufferSettings_{2048, 2048};
-
- void attachPipeline() {}
-
- void attachTransport(
- std::shared_ptr<AsyncTransport> transport) {
- transport_ = std::move(transport);
- }
-
- void detachTransport() {
- transport_ = nullptr;
- }
-
- OutboundChannelHandlerContext<W>* back_{nullptr};
-
- private:
- InboundChannelHandlerContext<R>* front_{nullptr};
- std::vector<std::unique_ptr<PipelineContext>> ctxs_;
-};
-
-template <class R, class W, class Handler, class... Handlers>
-class ChannelPipeline<R, W, Handler, Handlers...>
- : public ChannelPipeline<R, W, Handlers...> {
- protected:
- template <class HandlerArg, class... HandlersArgs>
- ChannelPipeline(
- bool shouldFinalize,
- HandlerArg&& handlerArg,
- HandlersArgs&&... handlersArgs)
- : ChannelPipeline<R, W, Handlers...>(
- false,
- std::forward<HandlersArgs>(handlersArgs)...),
- ctx_(this, std::forward<HandlerArg>(handlerArg)) {
- if (shouldFinalize) {
- finalize();
- }
- }
-
- public:
- template <class... HandlersArgs>
- explicit ChannelPipeline(HandlersArgs&&... handlersArgs)
- : ChannelPipeline(true, std::forward<HandlersArgs>(handlersArgs)...) {}
-
- ~ChannelPipeline() {}
-
- void destroy() override { }
-
- void read(R msg) {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- front_->read(std::forward<R>(msg));
- }
-
- void readEOF() {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- front_->readEOF();
- }
-
- void readException(exception_wrapper e) {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- front_->readException(std::move(e));
- }
-
- Future<void> write(W msg) {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- return back_->write(std::forward<W>(msg));
- }
-
- Future<void> close() {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- return back_->close();
- }
-
- void attachTransport(
- std::shared_ptr<AsyncTransport> transport) {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- CHECK((!ChannelPipeline<R, W>::transport_));
- ChannelPipeline<R, W, Handlers...>::attachTransport(std::move(transport));
- forEachCtx([&](PipelineContext* ctx){
- ctx->attachTransport();
- });
- }
-
- void detachTransport() {
- typename ChannelPipeline<R, W>::DestructorGuard dg(
- static_cast<DelayedDestruction*>(this));
- ChannelPipeline<R, W, Handlers...>::detachTransport();
- forEachCtx([&](PipelineContext* ctx){
- ctx->detachTransport();
- });
- }
-
- std::shared_ptr<AsyncTransport> getTransport() {
- return ChannelPipeline<R, W>::transport_;
- }
-
- template <class H>
- ChannelPipeline& addBack(H&& handler) {
- ChannelPipeline<R, W>::addBack(std::move(handler));
- return *this;
- }
-
- template <class H>
- ChannelPipeline& addFront(H&& handler) {
- ctxs_.insert(
- ctxs_.begin(),
- folly::make_unique<ContextImpl<ChannelPipeline, H>>(
- this,
- std::move(handler)));
- return *this;
- }
-
- template <class H>
- H* getHandler(size_t i) {
- if (i > ctxs_.size()) {
- return ChannelPipeline<R, W, Handlers...>::template getHandler<H>(
- i - (ctxs_.size() + 1));
- } else {
- auto pctx = (i == ctxs_.size()) ? &ctx_ : ctxs_[i].get();
- auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(pctx);
- return ctx->getHandler();
- }
- }
-
- void finalize() {
- finalizeHelper();
- auto ctx = ctxs_.empty() ? &ctx_ : ctxs_.front().get();
- front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(ctx);
- if (!front_) {
- throw std::invalid_argument("wrong type for first handler");
- }
- }
-
- protected:
- void finalizeHelper() {
- ChannelPipeline<R, W, Handlers...>::finalizeHelper();
- back_ = ChannelPipeline<R, W, Handlers...>::back_;
- if (!back_) {
- auto is_end = ChannelPipeline<R, W, Handlers...>::is_end;
- CHECK(is_end);
- back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(&ctx_);
- if (!back_) {
- throw std::invalid_argument("wrong type for last handler");
- }
- }
-
- if (!ctxs_.empty()) {
- for (int i = 0; i < ctxs_.size() - 1; i++) {
- ctxs_[i]->link(ctxs_[i+1].get());
- }
- ctxs_.back()->link(&ctx_);
- }
-
- auto nextFront = ChannelPipeline<R, W, Handlers...>::getLocalFront();
- if (nextFront) {
- ctx_.link(nextFront);
- }
- }
-
- PipelineContext* getLocalFront() {
- return ctxs_.empty() ? &ctx_ : ctxs_.front().get();
- }
-
- static const bool is_end{false};
- InboundChannelHandlerContext<R>* front_{nullptr};
- OutboundChannelHandlerContext<W>* back_{nullptr};
-
- private:
- template <class F>
- void forEachCtx(const F& func) {
- for (auto& ctx : ctxs_) {
- func(ctx.get());
- }
- func(&ctx_);
- }
-
- ContextImpl<ChannelPipeline, Handler> ctx_;
- std::vector<std::unique_ptr<PipelineContext>> ctxs_;
-};
-
-}}
-
-namespace folly {
-
-class AsyncSocket;
-
-template <typename Pipeline>
-class PipelineFactory {
- public:
- virtual Pipeline* newPipeline(std::shared_ptr<AsyncSocket>) = 0;
- virtual ~PipelineFactory() {}
-};
-
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/channel/ChannelHandler.h>
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/EventBaseManager.h>
-#include <folly/io/IOBuf.h>
-#include <folly/io/IOBufQueue.h>
-
-namespace folly { namespace wangle {
-
-/*
- * OutputBufferingHandler buffers writes in order to minimize syscalls. The
- * transport will be written to once per event loop instead of on every write.
- */
-class OutputBufferingHandler : public BytesToBytesHandler,
- protected EventBase::LoopCallback {
- public:
- Future<void> write(Context* ctx, std::unique_ptr<IOBuf> buf) override {
- CHECK(buf);
- if (!queueSends_) {
- return ctx->fireWrite(std::move(buf));
- } else {
- ctx_ = ctx;
- // Delay sends to optimize for fewer syscalls
- if (!sends_) {
- DCHECK(!isLoopCallbackScheduled());
- // Buffer all the sends, and call writev once per event loop.
- sends_ = std::move(buf);
- ctx->getTransport()->getEventBase()->runInLoop(this);
- } else {
- DCHECK(isLoopCallbackScheduled());
- sends_->prependChain(std::move(buf));
- }
- Promise<void> p;
- auto f = p.getFuture();
- promises_.push_back(std::move(p));
- return f;
- }
- }
-
- void runLoopCallback() noexcept override {
- MoveWrapper<std::vector<Promise<void>>> promises(std::move(promises_));
- ctx_->fireWrite(std::move(sends_)).then([promises](Try<void>&& t) mutable {
- try {
- t.throwIfFailed();
- for (auto& p : *promises) {
- p.setValue();
- }
- } catch (...) {
- for (auto& p : *promises) {
- p.setException(std::current_exception());
- }
- }
- });
- }
-
- std::vector<Promise<void>> promises_;
- std::unique_ptr<IOBuf> sends_{nullptr};
- bool queueSends_{true};
- Context* ctx_;
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/channel/ChannelHandler.h>
-#include <folly/experimental/wangle/channel/ChannelPipeline.h>
-#include <folly/experimental/wangle/channel/AsyncSocketHandler.h>
-#include <folly/experimental/wangle/channel/OutputBufferingHandler.h>
-#include <folly/experimental/wangle/channel/test/MockChannelHandler.h>
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-
-using namespace folly;
-using namespace folly::wangle;
-using namespace testing;
-
-typedef StrictMock<MockChannelHandlerAdapter<int, int>> IntHandler;
-typedef ChannelHandlerPtr<IntHandler, false> IntHandlerPtr;
-
-ACTION(FireRead) {
- arg0->fireRead(arg1);
-}
-
-ACTION(FireReadEOF) {
- arg0->fireReadEOF();
-}
-
-ACTION(FireReadException) {
- arg0->fireReadException(arg1);
-}
-
-ACTION(FireWrite) {
- arg0->fireWrite(arg1);
-}
-
-ACTION(FireClose) {
- arg0->fireClose();
-}
-
-// Test move only types, among other things
-TEST(ChannelTest, RealHandlersCompile) {
- EventBase eb;
- auto socket = AsyncSocket::newSocket(&eb);
- // static
- {
- ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>,
- AsyncSocketHandler,
- OutputBufferingHandler>
- pipeline{AsyncSocketHandler(socket), OutputBufferingHandler()};
- EXPECT_TRUE(pipeline.getHandler<AsyncSocketHandler>(0));
- EXPECT_TRUE(pipeline.getHandler<OutputBufferingHandler>(1));
- }
- // dynamic
- {
- ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>> pipeline;
- pipeline
- .addBack(AsyncSocketHandler(socket))
- .addBack(OutputBufferingHandler())
- .finalize();
- EXPECT_TRUE(pipeline.getHandler<AsyncSocketHandler>(0));
- EXPECT_TRUE(pipeline.getHandler<OutputBufferingHandler>(1));
- }
-}
-
-// Test that handlers correctly fire the next handler when directed
-TEST(ChannelTest, FireActions) {
- IntHandler handler1;
- IntHandler handler2;
-
- EXPECT_CALL(handler1, attachPipeline(_));
- EXPECT_CALL(handler2, attachPipeline(_));
-
- ChannelPipeline<int, int, IntHandlerPtr, IntHandlerPtr>
- pipeline(&handler1, &handler2);
-
- EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead());
- EXPECT_CALL(handler2, read_(_, _)).Times(1);
- pipeline.read(1);
-
- EXPECT_CALL(handler1, readEOF(_)).WillOnce(FireReadEOF());
- EXPECT_CALL(handler2, readEOF(_)).Times(1);
- pipeline.readEOF();
-
- EXPECT_CALL(handler1, readException(_, _)).WillOnce(FireReadException());
- EXPECT_CALL(handler2, readException(_, _)).Times(1);
- pipeline.readException(make_exception_wrapper<std::runtime_error>("blah"));
-
- EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite());
- EXPECT_CALL(handler1, write_(_, _)).Times(1);
- EXPECT_NO_THROW(pipeline.write(1).value());
-
- EXPECT_CALL(handler2, close_(_)).WillOnce(FireClose());
- EXPECT_CALL(handler1, close_(_)).Times(1);
- EXPECT_NO_THROW(pipeline.close().value());
-
- EXPECT_CALL(handler1, detachPipeline(_));
- EXPECT_CALL(handler2, detachPipeline(_));
-}
-
-// Test that nothing bad happens when actions reach the end of the pipeline
-// (a warning will be logged, however)
-TEST(ChannelTest, ReachEndOfPipeline) {
- IntHandler handler;
- EXPECT_CALL(handler, attachPipeline(_));
- ChannelPipeline<int, int, IntHandlerPtr>
- pipeline(&handler);
-
- EXPECT_CALL(handler, read_(_, _)).WillOnce(FireRead());
- pipeline.read(1);
-
- EXPECT_CALL(handler, readEOF(_)).WillOnce(FireReadEOF());
- pipeline.readEOF();
-
- EXPECT_CALL(handler, readException(_, _)).WillOnce(FireReadException());
- pipeline.readException(make_exception_wrapper<std::runtime_error>("blah"));
-
- EXPECT_CALL(handler, write_(_, _)).WillOnce(FireWrite());
- EXPECT_NO_THROW(pipeline.write(1).value());
-
- EXPECT_CALL(handler, close_(_)).WillOnce(FireClose());
- EXPECT_NO_THROW(pipeline.close().value());
-
- EXPECT_CALL(handler, detachPipeline(_));
-}
-
-// Test having the last read handler turn around and write
-TEST(ChannelTest, TurnAround) {
- IntHandler handler1;
- IntHandler handler2;
-
- EXPECT_CALL(handler1, attachPipeline(_));
- EXPECT_CALL(handler2, attachPipeline(_));
-
- ChannelPipeline<int, int, IntHandlerPtr, IntHandlerPtr>
- pipeline(&handler1, &handler2);
-
- EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead());
- EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireWrite());
- EXPECT_CALL(handler1, write_(_, _)).Times(1);
- pipeline.read(1);
-
- EXPECT_CALL(handler1, detachPipeline(_));
- EXPECT_CALL(handler2, detachPipeline(_));
-}
-
-TEST(ChannelTest, DynamicFireActions) {
- IntHandler handler1, handler2, handler3;
- EXPECT_CALL(handler2, attachPipeline(_));
- ChannelPipeline<int, int, IntHandlerPtr>
- pipeline(&handler2);
-
- EXPECT_CALL(handler1, attachPipeline(_));
- EXPECT_CALL(handler3, attachPipeline(_));
-
- pipeline
- .addFront(IntHandlerPtr(&handler1))
- .addBack(IntHandlerPtr(&handler3))
- .finalize();
-
- EXPECT_TRUE(pipeline.getHandler<IntHandlerPtr>(0));
- EXPECT_TRUE(pipeline.getHandler<IntHandlerPtr>(1));
- EXPECT_TRUE(pipeline.getHandler<IntHandlerPtr>(2));
-
- EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead());
- EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireRead());
- EXPECT_CALL(handler3, read_(_, _)).Times(1);
- pipeline.read(1);
-
- EXPECT_CALL(handler3, write_(_, _)).WillOnce(FireWrite());
- EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite());
- EXPECT_CALL(handler1, write_(_, _)).Times(1);
- EXPECT_NO_THROW(pipeline.write(1).value());
-
- EXPECT_CALL(handler1, detachPipeline(_));
- EXPECT_CALL(handler2, detachPipeline(_));
- EXPECT_CALL(handler3, detachPipeline(_));
-}
-
-template <class Rin, class Rout = Rin, class Win = Rout, class Wout = Rin>
-class ConcreteChannelHandler : public ChannelHandler<Rin, Rout, Win, Wout> {
- typedef typename ChannelHandler<Rin, Rout, Win, Wout>::Context Context;
- public:
- void read(Context* ctx, Rin msg) {}
- Future<void> write(Context* ctx, Win msg) { return makeFuture(); }
-};
-
-typedef ChannelHandlerAdapter<std::string, std::string> StringHandler;
-typedef ConcreteChannelHandler<int, std::string> IntToStringHandler;
-typedef ConcreteChannelHandler<std::string, int> StringToIntHandler;
-
-TEST(ChannelPipeline, DynamicConstruction) {
- {
- ChannelPipeline<int, int> pipeline;
- EXPECT_THROW(
- pipeline
- .addBack(ChannelHandlerAdapter<std::string, std::string>{})
- .finalize(), std::invalid_argument);
- }
- {
- ChannelPipeline<int, int> pipeline;
- EXPECT_THROW(
- pipeline
- .addFront(ChannelHandlerAdapter<std::string, std::string>{})
- .finalize(),
- std::invalid_argument);
- }
- {
- ChannelPipeline<std::string, std::string, StringHandler, StringHandler>
- pipeline{StringHandler(), StringHandler()};
-
- // Exercise both addFront and addBack. Final pipeline is
- // StI <-> ItS <-> StS <-> StS <-> StI <-> ItS
- EXPECT_NO_THROW(
- pipeline
- .addFront(IntToStringHandler{})
- .addFront(StringToIntHandler{})
- .addBack(StringToIntHandler{})
- .addBack(IntToStringHandler{})
- .finalize());
- }
-}
-
-TEST(ChannelPipeline, AttachTransport) {
- IntHandler handler;
- EXPECT_CALL(handler, attachPipeline(_));
- ChannelPipeline<int, int, IntHandlerPtr>
- pipeline(&handler);
-
- EventBase eb;
- auto socket = AsyncSocket::newSocket(&eb);
-
- EXPECT_CALL(handler, attachTransport(_));
- pipeline.attachTransport(socket);
-
- EXPECT_CALL(handler, detachTransport(_));
- pipeline.detachTransport();
-
- EXPECT_CALL(handler, detachPipeline(_));
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/channel/ChannelHandler.h>
-#include <gmock/gmock.h>
-
-namespace folly { namespace wangle {
-
-template <class Rin, class Rout = Rin, class Win = Rout, class Wout = Rin>
-class MockChannelHandler : public ChannelHandler<Rin, Rout, Win, Wout> {
- public:
- typedef typename ChannelHandler<Rin, Rout, Win, Wout>::Context Context;
-
- MockChannelHandler() = default;
- MockChannelHandler(MockChannelHandler&&) = default;
-
- MOCK_METHOD2_T(read_, void(Context*, Rin&));
- MOCK_METHOD1_T(readEOF, void(Context*));
- MOCK_METHOD2_T(readException, void(Context*, exception_wrapper));
-
- MOCK_METHOD2_T(write_, void(Context*, Win&));
- MOCK_METHOD1_T(close_, void(Context*));
-
- MOCK_METHOD1_T(attachPipeline, void(Context*));
- MOCK_METHOD1_T(attachTransport, void(Context*));
- MOCK_METHOD1_T(detachPipeline, void(Context*));
- MOCK_METHOD1_T(detachTransport, void(Context*));
-
- void read(Context* ctx, Rin msg) {
- read_(ctx, msg);
- }
-
- Future<void> write(Context* ctx, Win msg) override {
- return makeFutureTry([&](){
- write_(ctx, msg);
- });
- }
-
- Future<void> close(Context* ctx) override {
- return makeFutureTry([&](){
- close_(ctx);
- });
- }
-};
-
-template <class R, class W = R>
-using MockChannelHandlerAdapter = MockChannelHandler<R, R, W, W>;
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/channel/ChannelPipeline.h>
-#include <folly/experimental/wangle/channel/OutputBufferingHandler.h>
-#include <folly/experimental/wangle/channel/test/MockChannelHandler.h>
-#include <folly/io/async/AsyncSocket.h>
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-
-using namespace folly;
-using namespace folly::wangle;
-using namespace testing;
-
-typedef StrictMock<MockChannelHandlerAdapter<
- IOBufQueue&,
- std::unique_ptr<IOBuf>>>
-MockHandler;
-
-MATCHER_P(IOBufContains, str, "") { return arg->moveToFbString() == str; }
-
-TEST(OutputBufferingHandlerTest, Basic) {
- MockHandler mockHandler;
- EXPECT_CALL(mockHandler, attachPipeline(_));
- ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>,
- ChannelHandlerPtr<MockHandler, false>,
- OutputBufferingHandler>
- pipeline(&mockHandler, OutputBufferingHandler{});
-
- EventBase eb;
- auto socket = AsyncSocket::newSocket(&eb);
- EXPECT_CALL(mockHandler, attachTransport(_));
- pipeline.attachTransport(socket);
-
- // Buffering should prevent writes until the EB loops, and the writes should
- // be batched into one write call.
- auto f1 = pipeline.write(IOBuf::copyBuffer("hello"));
- auto f2 = pipeline.write(IOBuf::copyBuffer("world"));
- EXPECT_FALSE(f1.isReady());
- EXPECT_FALSE(f2.isReady());
- EXPECT_CALL(mockHandler, write_(_, IOBufContains("helloworld")));
- eb.loopOnce();
- EXPECT_TRUE(f1.isReady());
- EXPECT_TRUE(f2.isReady());
- EXPECT_CALL(mockHandler, detachPipeline(_));
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <glog/logging.h>
-
-namespace folly { namespace wangle {
-
-template <class T>
-class BlockingQueue {
- public:
- virtual ~BlockingQueue() {}
- virtual void add(T item) = 0;
- virtual void addWithPriority(T item, uint32_t priority) {
- LOG_FIRST_N(WARNING, 1) <<
- "add(item, priority) called on a non-priority queue";
- add(std::move(item));
- }
- virtual uint32_t getNumPriorities() {
- LOG_FIRST_N(WARNING, 1) <<
- "getNumPriorities() called on a non-priority queue";
- return 1;
- }
- virtual T take() = 0;
- virtual size_t size() = 0;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.h>
-#include <folly/experimental/wangle/concurrent/PriorityLifoSemMPMCQueue.h>
-
-namespace folly { namespace wangle {
-
-const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 18;
-const size_t CPUThreadPoolExecutor::kDefaultNumPriorities = 2;
-
-CPUThreadPoolExecutor::CPUThreadPoolExecutor(
- size_t numThreads,
- std::unique_ptr<BlockingQueue<CPUTask>> taskQueue,
- std::shared_ptr<ThreadFactory> threadFactory)
- : ThreadPoolExecutor(numThreads, std::move(threadFactory)),
- taskQueue_(std::move(taskQueue)) {
- addThreads(numThreads);
- CHECK(threadList_.get().size() == numThreads);
-}
-
-CPUThreadPoolExecutor::CPUThreadPoolExecutor(
- size_t numThreads,
- std::shared_ptr<ThreadFactory> threadFactory)
- : CPUThreadPoolExecutor(
- numThreads,
- folly::make_unique<LifoSemMPMCQueue<CPUTask>>(
- CPUThreadPoolExecutor::kDefaultMaxQueueSize),
- std::move(threadFactory)) {}
-
-CPUThreadPoolExecutor::CPUThreadPoolExecutor(size_t numThreads)
- : CPUThreadPoolExecutor(
- numThreads,
- std::make_shared<NamedThreadFactory>("CPUThreadPool")) {}
-
-CPUThreadPoolExecutor::CPUThreadPoolExecutor(
- size_t numThreads,
- uint32_t numPriorities,
- std::shared_ptr<ThreadFactory> threadFactory)
- : CPUThreadPoolExecutor(
- numThreads,
- folly::make_unique<PriorityLifoSemMPMCQueue<CPUTask>>(
- numPriorities,
- CPUThreadPoolExecutor::kDefaultMaxQueueSize),
- std::move(threadFactory)) {}
-
-CPUThreadPoolExecutor::~CPUThreadPoolExecutor() {
- stop();
- CHECK(threadsToStop_ == 0);
-}
-
-void CPUThreadPoolExecutor::add(Func func) {
- add(std::move(func), std::chrono::milliseconds(0));
-}
-
-void CPUThreadPoolExecutor::add(
- Func func,
- std::chrono::milliseconds expiration,
- Func expireCallback) {
- // TODO handle enqueue failure, here and in other add() callsites
- taskQueue_->add(
- CPUTask(std::move(func), expiration, std::move(expireCallback)));
-}
-
-void CPUThreadPoolExecutor::add(Func func, uint32_t priority) {
- add(std::move(func), priority, std::chrono::milliseconds(0));
-}
-
-void CPUThreadPoolExecutor::add(
- Func func,
- uint32_t priority,
- std::chrono::milliseconds expiration,
- Func expireCallback) {
- CHECK(priority < getNumPriorities());
- taskQueue_->addWithPriority(
- CPUTask(std::move(func), expiration, std::move(expireCallback)),
- priority);
-}
-
-uint32_t CPUThreadPoolExecutor::getNumPriorities() const {
- return taskQueue_->getNumPriorities();
-}
-
-BlockingQueue<CPUThreadPoolExecutor::CPUTask>*
-CPUThreadPoolExecutor::getTaskQueue() {
- return taskQueue_.get();
-}
-
-void CPUThreadPoolExecutor::threadRun(std::shared_ptr<Thread> thread) {
- thread->startupBaton.post();
- while (1) {
- auto task = taskQueue_->take();
- if (UNLIKELY(task.poison)) {
- CHECK(threadsToStop_-- > 0);
- stoppedThreads_.add(thread);
- return;
- } else {
- runTask(thread, std::move(task));
- }
-
- if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) {
- if (--threadsToStop_ >= 0) {
- stoppedThreads_.add(thread);
- return;
- } else {
- threadsToStop_++;
- }
- }
- }
-}
-
-void CPUThreadPoolExecutor::stopThreads(size_t n) {
- CHECK(stoppedThreads_.size() == 0);
- threadsToStop_ = n;
- for (size_t i = 0; i < n; i++) {
- taskQueue_->add(CPUTask());
- }
-}
-
-uint64_t CPUThreadPoolExecutor::getPendingTaskCount() {
- return taskQueue_->size();
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/concurrent/ThreadPoolExecutor.h>
-
-namespace folly { namespace wangle {
-
-class CPUThreadPoolExecutor : public ThreadPoolExecutor {
- public:
- struct CPUTask;
-
- explicit CPUThreadPoolExecutor(
- size_t numThreads,
- std::unique_ptr<BlockingQueue<CPUTask>> taskQueue,
- std::shared_ptr<ThreadFactory> threadFactory =
- std::make_shared<NamedThreadFactory>("CPUThreadPool"));
-
- explicit CPUThreadPoolExecutor(size_t numThreads);
-
- explicit CPUThreadPoolExecutor(
- size_t numThreads,
- std::shared_ptr<ThreadFactory> threadFactory);
-
- explicit CPUThreadPoolExecutor(
- size_t numThreads,
- uint32_t numPriorities,
- std::shared_ptr<ThreadFactory> threadFactory =
- std::make_shared<NamedThreadFactory>("CPUThreadPool"));
-
- ~CPUThreadPoolExecutor();
-
- void add(Func func) override;
- void add(
- Func func,
- std::chrono::milliseconds expiration,
- Func expireCallback = nullptr) override;
-
- void add(Func func, uint32_t priority);
- void add(
- Func func,
- uint32_t priority,
- std::chrono::milliseconds expiration,
- Func expireCallback = nullptr);
-
- uint32_t getNumPriorities() const;
-
- struct CPUTask : public ThreadPoolExecutor::Task {
- // Must be noexcept move constructible so it can be used in MPMCQueue
- explicit CPUTask(
- Func&& f,
- std::chrono::milliseconds expiration,
- Func&& expireCallback)
- : Task(std::move(f), expiration, std::move(expireCallback)),
- poison(false) {}
- CPUTask()
- : Task(nullptr, std::chrono::milliseconds(0), nullptr),
- poison(true) {}
- CPUTask(CPUTask&& o) noexcept : Task(std::move(o)), poison(o.poison) {}
- CPUTask(const CPUTask&) = default;
- CPUTask& operator=(const CPUTask&) = default;
- bool poison;
- };
-
- static const size_t kDefaultMaxQueueSize;
- static const size_t kDefaultNumPriorities;
-
- protected:
- BlockingQueue<CPUTask>* getTaskQueue();
-
- private:
- void threadRun(ThreadPtr thread) override;
- void stopThreads(size_t n) override;
- uint64_t getPendingTaskCount() override;
-
- std::unique_ptr<BlockingQueue<CPUTask>> taskQueue_;
- std::atomic<ssize_t> threadsToStop_{0};
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/concurrent/Codel.h>
-#include <algorithm>
-#include <math.h>
-
-#ifndef NO_LIB_GFLAGS
- #include <gflags/gflags.h>
- DEFINE_int32(codel_interval, 100,
- "Codel default interval time in ms");
- DEFINE_int32(codel_target_delay, 5,
- "Target codel queueing delay in ms");
-#endif
-
-namespace folly { namespace wangle {
-
-#ifdef NO_LIB_GFLAGS
- int32_t FLAGS_codel_interval = 100;
- int32_t FLAGS_codel_target_delay = 5;
-#endif
-
-Codel::Codel()
- : codelMinDelay_(0),
- codelIntervalTime_(std::chrono::steady_clock::now()),
- codelResetDelay_(true),
- overloaded_(false) {}
-
-bool Codel::overloaded(std::chrono::microseconds delay) {
- bool ret = false;
- auto now = std::chrono::steady_clock::now();
-
- // Avoid another thread updating the value at the same time we are using it
- // to calculate the overloaded state
- auto minDelay = codelMinDelay_;
-
- if (now > codelIntervalTime_ &&
- (!codelResetDelay_.load(std::memory_order_acquire)
- && !codelResetDelay_.exchange(true))) {
- codelIntervalTime_ = now + std::chrono::milliseconds(FLAGS_codel_interval);
-
- if (minDelay > std::chrono::milliseconds(FLAGS_codel_target_delay)) {
- overloaded_ = true;
- } else {
- overloaded_ = false;
- }
- }
- // Care must be taken that only a single thread resets codelMinDelay_,
- // and that it happens after the interval reset above
- if (codelResetDelay_.load(std::memory_order_acquire) &&
- codelResetDelay_.exchange(false)) {
- codelMinDelay_ = delay;
- // More than one request must come in during an interval before codel
- // starts dropping requests
- return false;
- } else if(delay < codelMinDelay_) {
- codelMinDelay_ = delay;
- }
-
- if (overloaded_ &&
- delay > std::chrono::milliseconds(FLAGS_codel_target_delay * 2)) {
- ret = true;
- }
-
- return ret;
-
-}
-
-int Codel::getLoad() {
- return std::min(100, (int)codelMinDelay_.count() /
- (2 * FLAGS_codel_target_delay));
-}
-
-int Codel::getMinDelay() {
- return (int) codelMinDelay_.count();
-}
-
-}} //namespace
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <atomic>
-#include <chrono>
-
-namespace folly { namespace wangle {
-
-/* Codel algorithm implementation:
- * http://en.wikipedia.org/wiki/CoDel
- *
- * Algorithm modified slightly: Instead of changing the interval time
- * based on the average min delay, instead we use an alternate timeout
- * for each task if the min delay during the interval period is too
- * high.
- *
- * This was found to have better latency metrics than changing the
- * window size, since we can communicate with the sender via thrift
- * instead of only via the tcp window size congestion control, as in TCP.
- */
-class Codel {
-
- public:
- Codel();
-
- // Given a delay, returns wether the codel algorithm would
- // reject a queued request with this delay.
- //
- // Internally, it also keeps track of the interval
- bool overloaded(std::chrono::microseconds delay);
-
- // Get the queue load, as seen by the codel algorithm
- // Gives a rough guess at how bad the queue delay is.
- //
- // Return: 0 = no delay, 100 = At the queueing limit
- int getLoad();
-
- int getMinDelay();
-
- private:
- std::chrono::microseconds codelMinDelay_;
- std::chrono::time_point<std::chrono::steady_clock> codelIntervalTime_;
-
- // flag to make overloaded() thread-safe, since we only want
- // to reset the delay once per time period
- std::atomic<bool> codelResetDelay_;
-
- bool overloaded_;
-};
-
-}} // Namespace
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include <folly/wangle/futures/Future.h>
-
-namespace folly { namespace wangle {
-
-template <typename ExecutorImpl>
-class FutureExecutor : public ExecutorImpl {
- public:
- template <typename... Args>
- explicit FutureExecutor(Args&&... args)
- : ExecutorImpl(std::forward<Args>(args)...) {}
-
- /*
- * Given a function func that returns a Future<T>, adds that function to the
- * contained Executor and returns a Future<T> which will be fulfilled with
- * func's result once it has been executed.
- *
- * For example: auto f = futureExecutor.addFuture([](){
- * return doAsyncWorkAndReturnAFuture();
- * });
- */
- template <typename F>
- typename std::enable_if<isFuture<typename std::result_of<F()>::type>::value,
- typename std::result_of<F()>::type>::type
- addFuture(F func) {
- typedef typename std::result_of<F()>::type::value_type T;
- Promise<T> promise;
- auto future = promise.getFuture();
- auto movePromise = folly::makeMoveWrapper(std::move(promise));
- auto moveFunc = folly::makeMoveWrapper(std::move(func));
- ExecutorImpl::add([movePromise, moveFunc] () mutable {
- (*moveFunc)().then([movePromise] (Try<T>&& t) mutable {
- movePromise->fulfilTry(std::move(t));
- });
- });
- return future;
- }
-
- /*
- * Similar to addFuture above, but takes a func that returns some non-Future
- * type T.
- *
- * For example: auto f = futureExecutor.addFuture([]() {
- * return 42;
- * });
- */
- template <typename F>
- typename std::enable_if<!isFuture<typename std::result_of<F()>::type>::value,
- Future<typename std::result_of<F()>::type>>::type
- addFuture(F func) {
- typedef typename std::result_of<F()>::type T;
- Promise<T> promise;
- auto future = promise.getFuture();
- auto movePromise = folly::makeMoveWrapper(std::move(promise));
- auto moveFunc = folly::makeMoveWrapper(std::move(func));
- ExecutorImpl::add([movePromise, moveFunc] () mutable {
- movePromise->fulfil(std::move(*moveFunc));
- });
- return future;
- }
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/Singleton.h>
-#include <folly/experimental/wangle/concurrent/IOExecutor.h>
-#include <folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h>
-
-using namespace folly;
-using namespace folly::wangle;
-
-namespace {
-
-Singleton<IOThreadPoolExecutor> globalIOThreadPoolSingleton(
- "GlobalIOThreadPool",
- [](){
- return new IOThreadPoolExecutor(
- sysconf(_SC_NPROCESSORS_ONLN),
- std::make_shared<NamedThreadFactory>("GlobalIOThreadPool"));
- });
-
-}
-
-namespace folly { namespace wangle {
-
-IOExecutor* getIOExecutor() {
- auto singleton = IOExecutor::getSingleton();
- auto executor = singleton->load();
- while (!executor) {
- IOExecutor* nullIOExecutor = nullptr;
- singleton->compare_exchange_strong(
- nullIOExecutor,
- Singleton<IOThreadPoolExecutor>::get("GlobalIOThreadPool"));
- executor = singleton->load();
- }
- return executor;
-}
-
-void setIOExecutor(IOExecutor* executor) {
- IOExecutor::getSingleton()->store(executor);
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-namespace folly { namespace wangle {
-
-class IOExecutor;
-
-// Retrieve the global IOExecutor. If there is none, a default
-// IOThreadPoolExecutor will be constructed and returned.
-IOExecutor* getIOExecutor();
-
-// Set an IOExecutor to be the global IOExecutor which will be returned by
-// subsequent calls to getIOExecutor(). IOExecutors will uninstall themselves
-// as global when they are destructed.
-void setIOExecutor(IOExecutor* executor);
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <folly/experimental/wangle/concurrent/IOExecutor.h>
-
-#include <folly/experimental/Singleton.h>
-#include <folly/experimental/wangle/concurrent/GlobalExecutor.h>
-
-using folly::Singleton;
-using folly::wangle::IOExecutor;
-
-namespace {
-
-Singleton<std::atomic<IOExecutor*>> globalIOExecutorSingleton(
- "GlobalIOExecutor",
- [](){
- return new std::atomic<IOExecutor*>(nullptr);
- });
-
-}
-
-namespace folly { namespace wangle {
-
-IOExecutor::~IOExecutor() {
- auto thisCopy = this;
- try {
- getSingleton()->compare_exchange_strong(thisCopy, nullptr);
- } catch (const std::runtime_error& e) {
- // The global IOExecutor singleton was already destructed so doesn't need to
- // be restored. Ignore.
- }
-}
-
-std::atomic<IOExecutor*>* IOExecutor::getSingleton() {
- return Singleton<std::atomic<IOExecutor*>>::get("GlobalIOExecutor");
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <atomic>
-#include <folly/Executor.h>
-
-namespace folly {
-class EventBase;
-}
-
-namespace folly { namespace wangle {
-
-// An IOExecutor is an executor that operates on at least one EventBase. One of
-// these EventBases should be accessible via getEventBase(). The event base
-// returned by a call to getEventBase() is implementation dependent.
-//
-// Note that IOExecutors don't necessarily loop on the base themselves - for
-// instance, EventBase itself is an IOExecutor but doesn't drive itself.
-//
-// Implementations of IOExecutor are eligible to become the global IO executor,
-// returned on every call to getIOExecutor(), via setIOExecutor().
-// These functions are declared in GlobalExecutor.h
-//
-// If getIOExecutor is called and none has been set, a default global
-// IOThreadPoolExecutor will be created and returned.
-class IOExecutor : public virtual Executor {
- public:
- virtual ~IOExecutor();
- virtual EventBase* getEventBase() = 0;
-
- private:
- static std::atomic<IOExecutor*>* getSingleton();
- friend IOExecutor* getIOExecutor();
- friend void setIOExecutor(IOExecutor* executor);
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h>
-
-#include <folly/MoveWrapper.h>
-#include <glog/logging.h>
-#include <folly/io/async/EventBaseManager.h>
-
-#include <folly/detail/MemoryIdler.h>
-
-namespace folly { namespace wangle {
-
-using folly::detail::MemoryIdler;
-
-/* Class that will free jemalloc caches and madvise the stack away
- * if the event loop is unused for some period of time
- */
-class MemoryIdlerTimeout
- : public AsyncTimeout , public EventBase::LoopCallback {
- public:
- explicit MemoryIdlerTimeout(EventBase* b) : AsyncTimeout(b), base_(b) {}
-
- virtual void timeoutExpired() noexcept {
- idled = true;
- }
-
- virtual void runLoopCallback() noexcept {
- if (idled) {
- MemoryIdler::flushLocalMallocCaches();
- MemoryIdler::unmapUnusedStack(MemoryIdler::kDefaultStackToRetain);
-
- idled = false;
- } else {
- std::chrono::steady_clock::duration idleTimeout =
- MemoryIdler::defaultIdleTimeout.load(
- std::memory_order_acquire);
-
- idleTimeout = MemoryIdler::getVariationTimeout(idleTimeout);
-
- scheduleTimeout(std::chrono::duration_cast<std::chrono::milliseconds>(
- idleTimeout).count());
- }
-
- // reschedule this callback for the next event loop.
- base_->runBeforeLoop(this);
- }
- private:
- EventBase* base_;
- bool idled{false};
-} ;
-
-IOThreadPoolExecutor::IOThreadPoolExecutor(
- size_t numThreads,
- std::shared_ptr<ThreadFactory> threadFactory)
- : ThreadPoolExecutor(numThreads, std::move(threadFactory)),
- nextThread_(0) {
- addThreads(numThreads);
- CHECK(threadList_.get().size() == numThreads);
-}
-
-IOThreadPoolExecutor::~IOThreadPoolExecutor() {
- stop();
-}
-
-void IOThreadPoolExecutor::add(Func func) {
- add(std::move(func), std::chrono::milliseconds(0));
-}
-
-void IOThreadPoolExecutor::add(
- Func func,
- std::chrono::milliseconds expiration,
- Func expireCallback) {
- RWSpinLock::ReadHolder{&threadListLock_};
- if (threadList_.get().empty()) {
- throw std::runtime_error("No threads available");
- }
- auto ioThread = pickThread();
-
- auto moveTask = folly::makeMoveWrapper(
- Task(std::move(func), expiration, std::move(expireCallback)));
- auto wrappedFunc = [ioThread, moveTask] () mutable {
- runTask(ioThread, std::move(*moveTask));
- ioThread->pendingTasks--;
- };
-
- ioThread->pendingTasks++;
- if (!ioThread->eventBase->runInEventBaseThread(std::move(wrappedFunc))) {
- ioThread->pendingTasks--;
- throw std::runtime_error("Unable to run func in event base thread");
- }
-}
-
-std::shared_ptr<IOThreadPoolExecutor::IOThread>
-IOThreadPoolExecutor::pickThread() {
- if (*thisThread_) {
- return *thisThread_;
- }
- auto thread = threadList_.get()[nextThread_++ % threadList_.get().size()];
- return std::static_pointer_cast<IOThread>(thread);
-}
-
-EventBase* IOThreadPoolExecutor::getEventBase() {
- return pickThread()->eventBase;
-}
-
-std::shared_ptr<ThreadPoolExecutor::Thread>
-IOThreadPoolExecutor::makeThread() {
- return std::make_shared<IOThread>(this);
-}
-
-void IOThreadPoolExecutor::threadRun(ThreadPtr thread) {
- const auto ioThread = std::static_pointer_cast<IOThread>(thread);
- ioThread->eventBase =
- folly::EventBaseManager::get()->getEventBase();
- thisThread_.reset(new std::shared_ptr<IOThread>(ioThread));
-
- auto idler = new MemoryIdlerTimeout(ioThread->eventBase);
- ioThread->eventBase->runBeforeLoop(idler);
-
- thread->startupBaton.post();
- while (ioThread->shouldRun) {
- ioThread->eventBase->loopForever();
- }
- if (isJoin_) {
- while (ioThread->pendingTasks > 0) {
- ioThread->eventBase->loopOnce();
- }
- }
- stoppedThreads_.add(ioThread);
-}
-
-// threadListLock_ is writelocked
-void IOThreadPoolExecutor::stopThreads(size_t n) {
- for (size_t i = 0; i < n; i++) {
- const auto ioThread = std::static_pointer_cast<IOThread>(
- threadList_.get()[i]);
- ioThread->shouldRun = false;
- ioThread->eventBase->terminateLoopSoon();
- }
-}
-
-std::vector<EventBase*> IOThreadPoolExecutor::getEventBases() {
- std::vector<EventBase*> bases;
- RWSpinLock::ReadHolder{&threadListLock_};
- for (const auto& thread : threadList_.get()) {
- auto ioThread = std::static_pointer_cast<IOThread>(thread);
- bases.push_back(ioThread->eventBase);
- }
- return bases;
-}
-
-// threadListLock_ is readlocked
-uint64_t IOThreadPoolExecutor::getPendingTaskCount() {
- uint64_t count = 0;
- for (const auto& thread : threadList_.get()) {
- auto ioThread = std::static_pointer_cast<IOThread>(thread);
- size_t pendingTasks = ioThread->pendingTasks;
- if (pendingTasks > 0 && !ioThread->idle) {
- pendingTasks--;
- }
- count += pendingTasks;
- }
- return count;
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/concurrent/IOExecutor.h>
-#include <folly/experimental/wangle/concurrent/ThreadPoolExecutor.h>
-#include <folly/io/async/EventBase.h>
-
-namespace folly { namespace wangle {
-
-// N.B. For this thread pool, stop() behaves like join() because outstanding
-// tasks belong to the event base and will be executed upon its destruction.
-class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor {
- public:
- explicit IOThreadPoolExecutor(
- size_t numThreads,
- std::shared_ptr<ThreadFactory> threadFactory =
- std::make_shared<NamedThreadFactory>("IOThreadPool"));
-
- ~IOThreadPoolExecutor();
-
- void add(Func func) override;
- void add(
- Func func,
- std::chrono::milliseconds expiration,
- Func expireCallback = nullptr) override;
-
- EventBase* getEventBase() override;
-
- std::vector<EventBase*> getEventBases();
-
- private:
- struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread {
- IOThread(IOThreadPoolExecutor* pool)
- : Thread(pool),
- shouldRun(true),
- pendingTasks(0) {};
- std::atomic<bool> shouldRun;
- std::atomic<size_t> pendingTasks;
- EventBase* eventBase;
- };
-
- ThreadPtr makeThread() override;
- std::shared_ptr<IOThread> pickThread();
- void threadRun(ThreadPtr thread) override;
- void stopThreads(size_t n) override;
- uint64_t getPendingTaskCount() override;
-
- size_t nextThread_;
- ThreadLocal<std::shared_ptr<IOThread>> thisThread_;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include <folly/experimental/wangle/concurrent/BlockingQueue.h>
-#include <folly/LifoSem.h>
-#include <folly/MPMCQueue.h>
-
-namespace folly { namespace wangle {
-
-template <class T>
-class LifoSemMPMCQueue : public BlockingQueue<T> {
- public:
- explicit LifoSemMPMCQueue(size_t max_capacity) : queue_(max_capacity) {}
-
- void add(T item) override {
- if (!queue_.write(std::move(item))) {
- throw std::runtime_error("LifoSemMPMCQueue full, can't add item");
- }
- sem_.post();
- }
-
- T take() override {
- T item;
- while (!queue_.read(item)) {
- sem_.wait();
- }
- return item;
- }
-
- size_t capacity() {
- return queue_.capacity();
- }
-
- size_t size() override {
- return queue_.size();
- }
-
- private:
- LifoSem sem_;
- MPMCQueue<T> queue_;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <atomic>
-#include <string>
-#include <thread>
-
-#include <folly/experimental/wangle/concurrent/ThreadFactory.h>
-#include <folly/Conv.h>
-#include <folly/Range.h>
-#include <folly/ThreadName.h>
-
-namespace folly { namespace wangle {
-
-class NamedThreadFactory : public ThreadFactory {
- public:
- explicit NamedThreadFactory(folly::StringPiece prefix)
- : prefix_(prefix.str()), suffix_(0) {}
-
- std::thread newThread(Func&& func) override {
- auto thread = std::thread(std::move(func));
- folly::setThreadName(
- thread.native_handle(),
- folly::to<std::string>(prefix_, suffix_++));
- return thread;
- }
-
- void setNamePrefix(folly::StringPiece prefix) {
- prefix_ = prefix.str();
- }
-
- std::string getNamePrefix() {
- return prefix_;
- }
-
- private:
- std::string prefix_;
- std::atomic<uint64_t> suffix_;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include <folly/experimental/wangle/concurrent/BlockingQueue.h>
-#include <folly/LifoSem.h>
-#include <folly/MPMCQueue.h>
-
-namespace folly { namespace wangle {
-
-template <class T>
-class PriorityLifoSemMPMCQueue : public BlockingQueue<T> {
- public:
- explicit PriorityLifoSemMPMCQueue(uint32_t numPriorities, size_t capacity) {
- CHECK(numPriorities > 0);
- queues_.reserve(numPriorities);
- for (int i = 0; i < numPriorities; i++) {
- queues_.push_back(MPMCQueue<T>(capacity));
- }
- }
-
- uint32_t getNumPriorities() override {
- return queues_.size();
- }
-
- // Add at lowest priority by default
- void add(T item) override {
- addWithPriority(std::move(item), 0);
- }
-
- void addWithPriority(T item, uint32_t priority) override {
- CHECK(priority < queues_.size());
- if (!queues_[priority].write(std::move(item))) {
- throw std::runtime_error("LifoSemMPMCQueue full, can't add item");
- }
- sem_.post();
- }
-
- T take() override {
- T item;
- while (true) {
- for (auto it = queues_.rbegin(); it != queues_.rend(); it++) {
- if (it->read(item)) {
- return item;
- }
- }
- sem_.wait();
- }
- }
-
- size_t size() override {
- size_t size = 0;
- for (auto& q : queues_) {
- size += q.size();
- }
- return size;
- }
-
- private:
- LifoSem sem_;
- std::vector<MPMCQueue<T>> queues_;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include <folly/Executor.h>
-
-#include <thread>
-
-namespace folly { namespace wangle {
-
-class ThreadFactory {
- public:
- virtual ~ThreadFactory() {}
- virtual std::thread newThread(Func&& func) = 0;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/concurrent/ThreadPoolExecutor.h>
-
-namespace folly { namespace wangle {
-
-ThreadPoolExecutor::ThreadPoolExecutor(
- size_t numThreads,
- std::shared_ptr<ThreadFactory> threadFactory)
- : threadFactory_(std::move(threadFactory)),
- taskStatsSubject_(std::make_shared<Subject<TaskStats>>()) {}
-
-ThreadPoolExecutor::~ThreadPoolExecutor() {
- CHECK(threadList_.get().size() == 0);
-}
-
-ThreadPoolExecutor::Task::Task(
- Func&& func,
- std::chrono::milliseconds expiration,
- Func&& expireCallback)
- : func_(std::move(func)),
- expiration_(expiration),
- expireCallback_(std::move(expireCallback)) {
- // Assume that the task in enqueued on creation
- enqueueTime_ = std::chrono::steady_clock::now();
-}
-
-void ThreadPoolExecutor::runTask(
- const ThreadPtr& thread,
- Task&& task) {
- thread->idle = false;
- auto startTime = std::chrono::steady_clock::now();
- task.stats_.waitTime = startTime - task.enqueueTime_;
- if (task.expiration_ > std::chrono::milliseconds(0) &&
- task.stats_.waitTime >= task.expiration_) {
- task.stats_.expired = true;
- if (task.expireCallback_ != nullptr) {
- task.expireCallback_();
- }
- } else {
- try {
- task.func_();
- } catch (const std::exception& e) {
- LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " <<
- typeid(e).name() << " exception: " << e.what();
- } catch (...) {
- LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
- "object";
- }
- task.stats_.runTime = std::chrono::steady_clock::now() - startTime;
- }
- thread->idle = true;
- thread->taskStatsSubject->onNext(std::move(task.stats_));
-}
-
-size_t ThreadPoolExecutor::numThreads() {
- RWSpinLock::ReadHolder{&threadListLock_};
- return threadList_.get().size();
-}
-
-void ThreadPoolExecutor::setNumThreads(size_t n) {
- RWSpinLock::WriteHolder{&threadListLock_};
- const auto current = threadList_.get().size();
- if (n > current ) {
- addThreads(n - current);
- } else if (n < current) {
- removeThreads(current - n, true);
- }
- CHECK(threadList_.get().size() == n);
-}
-
-// threadListLock_ is writelocked
-void ThreadPoolExecutor::addThreads(size_t n) {
- std::vector<ThreadPtr> newThreads;
- for (size_t i = 0; i < n; i++) {
- newThreads.push_back(makeThread());
- }
- for (auto& thread : newThreads) {
- // TODO need a notion of failing to create the thread
- // and then handling for that case
- thread->handle = threadFactory_->newThread(
- std::bind(&ThreadPoolExecutor::threadRun, this, thread));
- threadList_.add(thread);
- }
- for (auto& thread : newThreads) {
- thread->startupBaton.wait();
- }
-}
-
-// threadListLock_ is writelocked
-void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) {
- CHECK(n <= threadList_.get().size());
- CHECK(stoppedThreads_.size() == 0);
- isJoin_ = isJoin;
- stopThreads(n);
- for (size_t i = 0; i < n; i++) {
- auto thread = stoppedThreads_.take();
- thread->handle.join();
- threadList_.remove(thread);
- }
- CHECK(stoppedThreads_.size() == 0);
-}
-
-void ThreadPoolExecutor::stop() {
- RWSpinLock::WriteHolder{&threadListLock_};
- removeThreads(threadList_.get().size(), false);
- CHECK(threadList_.get().size() == 0);
-}
-
-void ThreadPoolExecutor::join() {
- RWSpinLock::WriteHolder{&threadListLock_};
- removeThreads(threadList_.get().size(), true);
- CHECK(threadList_.get().size() == 0);
-}
-
-ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() {
- RWSpinLock::ReadHolder{&threadListLock_};
- ThreadPoolExecutor::PoolStats stats;
- stats.threadCount = threadList_.get().size();
- for (auto thread : threadList_.get()) {
- if (thread->idle) {
- stats.idleThreadCount++;
- } else {
- stats.activeThreadCount++;
- }
- }
- stats.pendingTaskCount = getPendingTaskCount();
- stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount;
- return stats;
-}
-
-std::atomic<uint64_t> ThreadPoolExecutor::Thread::nextId(0);
-
-void ThreadPoolExecutor::StoppedThreadQueue::add(
- ThreadPoolExecutor::ThreadPtr item) {
- std::lock_guard<std::mutex> guard(mutex_);
- queue_.push(std::move(item));
- sem_.post();
-}
-
-ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() {
- while(1) {
- {
- std::lock_guard<std::mutex> guard(mutex_);
- if (queue_.size() > 0) {
- auto item = std::move(queue_.front());
- queue_.pop();
- return item;
- }
- }
- sem_.wait();
- }
-}
-
-size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
- std::lock_guard<std::mutex> guard(mutex_);
- return queue_.size();
-}
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-#include <folly/Executor.h>
-#include <folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h>
-#include <folly/experimental/wangle/concurrent/NamedThreadFactory.h>
-#include <folly/experimental/wangle/rx/Observable.h>
-#include <folly/Baton.h>
-#include <folly/Memory.h>
-#include <folly/RWSpinLock.h>
-
-#include <algorithm>
-#include <mutex>
-#include <queue>
-
-#include <glog/logging.h>
-
-namespace folly { namespace wangle {
-
-class ThreadPoolExecutor : public virtual Executor {
- public:
- explicit ThreadPoolExecutor(
- size_t numThreads,
- std::shared_ptr<ThreadFactory> threadFactory);
-
- ~ThreadPoolExecutor();
-
- virtual void add(Func func) override = 0;
- virtual void add(
- Func func,
- std::chrono::milliseconds expiration,
- Func expireCallback) = 0;
-
- void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
- CHECK(numThreads() == 0);
- threadFactory_ = std::move(threadFactory);
- }
-
- std::shared_ptr<ThreadFactory> getThreadFactory(void) {
- return threadFactory_;
- }
-
- size_t numThreads();
- void setNumThreads(size_t numThreads);
- /*
- * stop() is best effort - there is no guarantee that unexecuted tasks won't
- * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
- * behaves like join().
- */
- void stop();
- void join();
-
- struct PoolStats {
- PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
- pendingTaskCount(0), totalTaskCount(0) {}
- size_t threadCount, idleThreadCount, activeThreadCount;
- uint64_t pendingTaskCount, totalTaskCount;
- };
-
- PoolStats getPoolStats();
-
- struct TaskStats {
- TaskStats() : expired(false), waitTime(0), runTime(0) {}
- bool expired;
- std::chrono::nanoseconds waitTime;
- std::chrono::nanoseconds runTime;
- };
-
- Subscription<TaskStats> subscribeToTaskStats(
- const ObserverPtr<TaskStats>& observer) {
- return taskStatsSubject_->subscribe(observer);
- }
-
- protected:
- // Prerequisite: threadListLock_ writelocked
- void addThreads(size_t n);
- // Prerequisite: threadListLock_ writelocked
- void removeThreads(size_t n, bool isJoin);
-
- struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread {
- explicit Thread(ThreadPoolExecutor* pool)
- : id(nextId++),
- handle(),
- idle(true),
- taskStatsSubject(pool->taskStatsSubject_) {}
-
- virtual ~Thread() {}
-
- static std::atomic<uint64_t> nextId;
- uint64_t id;
- std::thread handle;
- bool idle;
- Baton<> startupBaton;
- std::shared_ptr<Subject<TaskStats>> taskStatsSubject;
- };
-
- typedef std::shared_ptr<Thread> ThreadPtr;
-
- struct Task {
- explicit Task(
- Func&& func,
- std::chrono::milliseconds expiration,
- Func&& expireCallback);
- Func func_;
- TaskStats stats_;
- std::chrono::steady_clock::time_point enqueueTime_;
- std::chrono::milliseconds expiration_;
- Func expireCallback_;
- };
-
- static void runTask(const ThreadPtr& thread, Task&& task);
-
- // The function that will be bound to pool threads. It must call
- // thread->startupBaton.post() when it's ready to consume work.
- virtual void threadRun(ThreadPtr thread) = 0;
-
- // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
- // Prerequisite: threadListLock_ writelocked
- virtual void stopThreads(size_t n) = 0;
-
- // Create a suitable Thread struct
- virtual ThreadPtr makeThread() {
- return std::make_shared<Thread>(this);
- }
-
- // Prerequisite: threadListLock_ readlocked
- virtual uint64_t getPendingTaskCount() = 0;
-
- class ThreadList {
- public:
- void add(const ThreadPtr& state) {
- auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare);
- vec_.insert(it, state);
- }
-
- void remove(const ThreadPtr& state) {
- auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare);
- CHECK(itPair.first != vec_.end());
- CHECK(std::next(itPair.first) == itPair.second);
- vec_.erase(itPair.first);
- }
-
- const std::vector<ThreadPtr>& get() const {
- return vec_;
- }
-
- private:
- static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
- return ts1->id < ts2->id;
- }
-
- std::vector<ThreadPtr> vec_;
- };
-
- class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
- public:
- void add(ThreadPtr item) override;
- ThreadPtr take() override;
- size_t size() override;
-
- private:
- LifoSem sem_;
- std::mutex mutex_;
- std::queue<ThreadPtr> queue_;
- };
-
- std::shared_ptr<ThreadFactory> threadFactory_;
- ThreadList threadList_;
- RWSpinLock threadListLock_;
- StoppedThreadQueue stoppedThreads_;
- std::atomic<bool> isJoin_; // whether the current downsizing is a join
-
- std::shared_ptr<Subject<TaskStats>> taskStatsSubject_;
-};
-
-}} // folly::wangle
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <chrono>
-#include <folly/experimental/wangle/concurrent/Codel.h>
-#include <gtest/gtest.h>
-#include <thread>
-
-TEST(CodelTest, Basic) {
- using std::chrono::milliseconds;
- folly::wangle::Codel c;
- std::this_thread::sleep_for(milliseconds(110));
- // This interval is overloaded
- EXPECT_FALSE(c.overloaded(milliseconds(100)));
- std::this_thread::sleep_for(milliseconds(90));
- // At least two requests must happen in an interval before they will fail
- EXPECT_FALSE(c.overloaded(milliseconds(50)));
- EXPECT_TRUE(c.overloaded(milliseconds(50)));
- std::this_thread::sleep_for(milliseconds(110));
- // Previous interval is overloaded, but 2ms isn't enough to fail
- EXPECT_FALSE(c.overloaded(milliseconds(2)));
- std::this_thread::sleep_for(milliseconds(90));
- // 20 ms > target interval * 2
- EXPECT_TRUE(c.overloaded(milliseconds(20)));
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <gtest/gtest.h>
-#include <folly/experimental/wangle/concurrent/GlobalExecutor.h>
-#include <folly/experimental/wangle/concurrent/IOExecutor.h>
-
-using namespace folly::wangle;
-
-TEST(GlobalExecutorTest, GlobalIOExecutor) {
- class DummyExecutor : public IOExecutor {
- public:
- void add(folly::Func f) override {
- count++;
- }
- folly::EventBase* getEventBase() override {
- return nullptr;
- }
- int count{0};
- };
-
- auto f = [](){};
-
- // Don't explode, we should create the default global IOExecutor lazily here.
- getIOExecutor()->add(f);
-
- {
- DummyExecutor dummy;
- setIOExecutor(&dummy);
- getIOExecutor()->add(f);
- // Make sure we were properly installed.
- EXPECT_EQ(1, dummy.count);
- }
-
- // Don't explode, we should restore the default global IOExecutor when dummy
- // is destructed.
- getIOExecutor()->add(f);
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/concurrent/FutureExecutor.h>
-#include <folly/experimental/wangle/concurrent/ThreadPoolExecutor.h>
-#include <folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.h>
-#include <folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h>
-#include <glog/logging.h>
-#include <gtest/gtest.h>
-
-using namespace folly::wangle;
-using namespace std::chrono;
-
-static folly::Func burnMs(uint64_t ms) {
- return [ms]() { std::this_thread::sleep_for(milliseconds(ms)); };
-}
-
-template <class TPE>
-static void basic() {
- // Create and destroy
- TPE tpe(10);
-}
-
-TEST(ThreadPoolExecutorTest, CPUBasic) {
- basic<CPUThreadPoolExecutor>();
-}
-
-TEST(IOThreadPoolExecutorTest, IOBasic) {
- basic<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void resize() {
- TPE tpe(100);
- EXPECT_EQ(100, tpe.numThreads());
- tpe.setNumThreads(50);
- EXPECT_EQ(50, tpe.numThreads());
- tpe.setNumThreads(150);
- EXPECT_EQ(150, tpe.numThreads());
-}
-
-TEST(ThreadPoolExecutorTest, CPUResize) {
- resize<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOResize) {
- resize<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void stop() {
- TPE tpe(1);
- std::atomic<int> completed(0);
- auto f = [&](){
- burnMs(10)();
- completed++;
- };
- for (int i = 0; i < 1000; i++) {
- tpe.add(f);
- }
- tpe.stop();
- EXPECT_GT(1000, completed);
-}
-
-// IOThreadPoolExecutor's stop() behaves like join(). Outstanding tasks belong
-// to the event base, will be executed upon its destruction, and cannot be
-// taken back.
-template <>
-void stop<IOThreadPoolExecutor>() {
- IOThreadPoolExecutor tpe(1);
- std::atomic<int> completed(0);
- auto f = [&](){
- burnMs(10)();
- completed++;
- };
- for (int i = 0; i < 10; i++) {
- tpe.add(f);
- }
- tpe.stop();
- EXPECT_EQ(10, completed);
-}
-
-TEST(ThreadPoolExecutorTest, CPUStop) {
- stop<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOStop) {
- stop<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void join() {
- TPE tpe(10);
- std::atomic<int> completed(0);
- auto f = [&](){
- burnMs(1)();
- completed++;
- };
- for (int i = 0; i < 1000; i++) {
- tpe.add(f);
- }
- tpe.join();
- EXPECT_EQ(1000, completed);
-}
-
-TEST(ThreadPoolExecutorTest, CPUJoin) {
- join<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOJoin) {
- join<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void resizeUnderLoad() {
- TPE tpe(10);
- std::atomic<int> completed(0);
- auto f = [&](){
- burnMs(1)();
- completed++;
- };
- for (int i = 0; i < 1000; i++) {
- tpe.add(f);
- }
- tpe.setNumThreads(5);
- tpe.setNumThreads(15);
- tpe.join();
- EXPECT_EQ(1000, completed);
-}
-
-TEST(ThreadPoolExecutorTest, CPUResizeUnderLoad) {
- resizeUnderLoad<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOResizeUnderLoad) {
- resizeUnderLoad<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void poolStats() {
- folly::Baton<> startBaton, endBaton;
- TPE tpe(1);
- auto stats = tpe.getPoolStats();
- EXPECT_EQ(1, stats.threadCount);
- EXPECT_EQ(1, stats.idleThreadCount);
- EXPECT_EQ(0, stats.activeThreadCount);
- EXPECT_EQ(0, stats.pendingTaskCount);
- EXPECT_EQ(0, stats.totalTaskCount);
- tpe.add([&](){ startBaton.post(); endBaton.wait(); });
- tpe.add([&](){});
- startBaton.wait();
- stats = tpe.getPoolStats();
- EXPECT_EQ(1, stats.threadCount);
- EXPECT_EQ(0, stats.idleThreadCount);
- EXPECT_EQ(1, stats.activeThreadCount);
- EXPECT_EQ(1, stats.pendingTaskCount);
- EXPECT_EQ(2, stats.totalTaskCount);
- endBaton.post();
-}
-
-TEST(ThreadPoolExecutorTest, CPUPoolStats) {
- poolStats<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOPoolStats) {
- poolStats<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void taskStats() {
- TPE tpe(1);
- std::atomic<int> c(0);
- auto s = tpe.subscribeToTaskStats(
- Observer<ThreadPoolExecutor::TaskStats>::create(
- [&](ThreadPoolExecutor::TaskStats stats) {
- int i = c++;
- EXPECT_LT(milliseconds(0), stats.runTime);
- if (i == 1) {
- EXPECT_LT(milliseconds(0), stats.waitTime);
- }
- }));
- tpe.add(burnMs(10));
- tpe.add(burnMs(10));
- tpe.join();
- EXPECT_EQ(2, c);
-}
-
-TEST(ThreadPoolExecutorTest, CPUTaskStats) {
- taskStats<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOTaskStats) {
- taskStats<IOThreadPoolExecutor>();
-}
-
-template <class TPE>
-static void expiration() {
- TPE tpe(1);
- std::atomic<int> statCbCount(0);
- auto s = tpe.subscribeToTaskStats(
- Observer<ThreadPoolExecutor::TaskStats>::create(
- [&](ThreadPoolExecutor::TaskStats stats) {
- int i = statCbCount++;
- if (i == 0) {
- EXPECT_FALSE(stats.expired);
- } else if (i == 1) {
- EXPECT_TRUE(stats.expired);
- } else {
- FAIL();
- }
- }));
- std::atomic<int> expireCbCount(0);
- auto expireCb = [&] () { expireCbCount++; };
- tpe.add(burnMs(10), seconds(60), expireCb);
- tpe.add(burnMs(10), milliseconds(10), expireCb);
- tpe.join();
- EXPECT_EQ(2, statCbCount);
- EXPECT_EQ(1, expireCbCount);
-}
-
-TEST(ThreadPoolExecutorTest, CPUExpiration) {
- expiration<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOExpiration) {
- expiration<IOThreadPoolExecutor>();
-}
-
-template <typename TPE>
-static void futureExecutor() {
- FutureExecutor<TPE> fe(2);
- std::atomic<int> c{0};
- fe.addFuture([] () { return makeFuture<int>(42); }).then(
- [&] (Try<int>&& t) {
- c++;
- EXPECT_EQ(42, t.value());
- });
- fe.addFuture([] () { return 100; }).then(
- [&] (Try<int>&& t) {
- c++;
- EXPECT_EQ(100, t.value());
- });
- fe.addFuture([] () { return makeFuture(); }).then(
- [&] (Try<void>&& t) {
- c++;
- EXPECT_NO_THROW(t.value());
- });
- fe.addFuture([] () { return; }).then(
- [&] (Try<void>&& t) {
- c++;
- EXPECT_NO_THROW(t.value());
- });
- fe.addFuture([] () { throw std::runtime_error("oops"); }).then(
- [&] (Try<void>&& t) {
- c++;
- EXPECT_THROW(t.value(), std::runtime_error);
- });
- // Test doing actual async work
- folly::Baton<> baton;
- fe.addFuture([&] () {
- auto p = std::make_shared<Promise<int>>();
- std::thread t([p](){
- burnMs(10)();
- p->setValue(42);
- });
- t.detach();
- return p->getFuture();
- }).then([&] (Try<int>&& t) {
- EXPECT_EQ(42, t.value());
- c++;
- baton.post();
- });
- baton.wait();
- fe.join();
- EXPECT_EQ(6, c);
-}
-
-TEST(ThreadPoolExecutorTest, CPUFuturePool) {
- futureExecutor<CPUThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, IOFuturePool) {
- futureExecutor<IOThreadPoolExecutor>();
-}
-
-TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) {
- bool tookLopri = false;
- auto completed = 0;
- auto hipri = [&] {
- EXPECT_FALSE(tookLopri);
- completed++;
- };
- auto lopri = [&] {
- tookLopri = true;
- completed++;
- };
- CPUThreadPoolExecutor pool(0, 2);
- for (int i = 0; i < 50; i++) {
- pool.add(lopri, 0);
- }
- for (int i = 0; i < 50; i++) {
- pool.add(hipri, 1);
- }
- pool.setNumThreads(1);
- pool.join();
- EXPECT_EQ(100, completed);
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// fbbuild is too dumb to know that .h files in the directory affect
-// our project, unless we have a .cpp file in the target, in the same
-// directory.
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/rx/Subject.h>
-#include <folly/experimental/wangle/rx/Subscription.h>
-#include <folly/experimental/wangle/rx/types.h>
-
-#include <folly/RWSpinLock.h>
-#include <folly/SmallLocks.h>
-#include <folly/ThreadLocal.h>
-#include <folly/small_vector.h>
-#include <folly/Executor.h>
-#include <map>
-#include <memory>
-
-namespace folly { namespace wangle {
-
-template <class T, size_t InlineObservers>
-class Observable {
- public:
- Observable() : nextSubscriptionId_{1} {}
-
- // TODO perhaps we want to provide this #5283229
- Observable(Observable&& other) = delete;
-
- virtual ~Observable() {
- if (unsubscriber_) {
- unsubscriber_->disable();
- }
- }
-
- // The next three methods subscribe the given Observer to this Observable.
- //
- // If these are called within an Observer callback, the new observer will not
- // get the current update but will get subsequent updates.
- //
- // subscribe() returns a Subscription object. The observer will continue to
- // get updates until the Subscription is destroyed.
- //
- // observe(ObserverPtr<T>) creates an indefinite subscription
- //
- // observe(Observer<T>*) also creates an indefinite subscription, but the
- // caller is responsible for ensuring that the given Observer outlives this
- // Observable. This might be useful in high performance environments where
- // allocations must be kept to a minimum. Template parameter InlineObservers
- // specifies how many observers can been subscribed inline without any
- // allocations (it's just the size of a folly::small_vector).
- virtual Subscription<T> subscribe(ObserverPtr<T> observer) {
- return subscribeImpl(observer, false);
- }
-
- virtual void observe(ObserverPtr<T> observer) {
- subscribeImpl(observer, true);
- }
-
- virtual void observe(Observer<T>* observer) {
- if (inCallback_ && *inCallback_) {
- if (!newObservers_) {
- newObservers_.reset(new ObserverList());
- }
- newObservers_->push_back(observer);
- } else {
- RWSpinLock::WriteHolder{&observersLock_};
- observers_.push_back(observer);
- }
- }
-
- // TODO unobserve(ObserverPtr<T>), unobserve(Observer<T>*)
-
- /// Returns a new Observable that will call back on the given Scheduler.
- /// The returned Observable must outlive the parent Observable.
-
- // This and subscribeOn should maybe just be a first-class feature of an
- // Observable, rather than making new ones whose lifetimes are tied to their
- // parents. In that case it'd return a reference to this object for
- // chaining.
- ObservablePtr<T> observeOn(SchedulerPtr scheduler) {
- // you're right Hannes, if we have Observable::create we don't need this
- // helper class.
- struct ViaSubject : public Observable<T>
- {
- ViaSubject(SchedulerPtr sched,
- Observable* obs)
- : scheduler_(sched), observable_(obs)
- {}
-
- Subscription<T> subscribe(ObserverPtr<T> o) override {
- return observable_->subscribe(
- Observer<T>::create(
- [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); },
- [=](Error e) { scheduler_->add([o, e] { o->onError(e); }); },
- [=]() { scheduler_->add([o] { o->onCompleted(); }); }));
- }
-
- protected:
- SchedulerPtr scheduler_;
- Observable* observable_;
- };
-
- return std::make_shared<ViaSubject>(scheduler, this);
- }
-
- /// Returns a new Observable that will subscribe to this parent Observable
- /// via the given Scheduler. This can be subtle and confusing at first, see
- /// http://www.introtorx.com/Content/v1.0.10621.0/15_SchedulingAndThreading.html#SubscribeOnObserveOn
- std::unique_ptr<Observable> subscribeOn(SchedulerPtr scheduler) {
- struct Subject_ : public Subject<T> {
- public:
- Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) {
- }
-
- Subscription<T> subscribe(ObserverPtr<T> o) {
- scheduler_->add([=] {
- observable_->subscribe(o);
- });
- return Subscription<T>(nullptr, 0); // TODO
- }
-
- protected:
- SchedulerPtr scheduler_;
- Observable* observable_;
- };
-
- return folly::make_unique<Subject_>(scheduler, this);
- }
-
- protected:
- // Safely execute an operation on each observer. F must take a single
- // Observer<T>* as its argument.
- template <class F>
- void forEachObserver(F f) {
- if (UNLIKELY(!inCallback_)) {
- inCallback_.reset(new bool{false});
- }
- CHECK(!(*inCallback_));
- *inCallback_ = true;
-
- {
- RWSpinLock::ReadHolder rh(observersLock_);
- for (auto o : observers_) {
- f(o);
- }
-
- for (auto& kv : subscribers_) {
- f(kv.second.get());
- }
- }
-
- if (UNLIKELY((newObservers_ && !newObservers_->empty()) ||
- (newSubscribers_ && !newSubscribers_->empty()) ||
- (oldSubscribers_ && !oldSubscribers_->empty()))) {
- {
- RWSpinLock::WriteHolder wh(observersLock_);
- if (newObservers_) {
- for (auto observer : *(newObservers_)) {
- observers_.push_back(observer);
- }
- newObservers_->clear();
- }
- if (newSubscribers_) {
- for (auto& kv : *(newSubscribers_)) {
- subscribers_.insert(std::move(kv));
- }
- newSubscribers_->clear();
- }
- if (oldSubscribers_) {
- for (auto id : *(oldSubscribers_)) {
- subscribers_.erase(id);
- }
- oldSubscribers_->clear();
- }
- }
- }
- *inCallback_ = false;
- }
-
- private:
- Subscription<T> subscribeImpl(ObserverPtr<T> observer, bool indefinite) {
- auto subscription = makeSubscription(indefinite);
- typename SubscriberMap::value_type kv{subscription.id_, std::move(observer)};
- if (inCallback_ && *inCallback_) {
- if (!newSubscribers_) {
- newSubscribers_.reset(new SubscriberMap());
- }
- newSubscribers_->insert(std::move(kv));
- } else {
- RWSpinLock::WriteHolder{&observersLock_};
- subscribers_.insert(std::move(kv));
- }
- return subscription;
- }
-
- class Unsubscriber {
- public:
- explicit Unsubscriber(Observable* observable) : observable_(observable) {
- CHECK(observable_);
- }
-
- void unsubscribe(uint64_t id) {
- CHECK(id > 0);
- RWSpinLock::ReadHolder guard(lock_);
- if (observable_) {
- observable_->unsubscribe(id);
- }
- }
-
- void disable() {
- RWSpinLock::WriteHolder guard(lock_);
- observable_ = nullptr;
- }
-
- private:
- RWSpinLock lock_;
- Observable* observable_;
- };
-
- std::shared_ptr<Unsubscriber> unsubscriber_{nullptr};
- MicroSpinLock unsubscriberLock_{0};
-
- friend class Subscription<T>;
-
- void unsubscribe(uint64_t id) {
- if (inCallback_ && *inCallback_) {
- if (!oldSubscribers_) {
- oldSubscribers_.reset(new std::vector<uint64_t>());
- }
- if (newSubscribers_) {
- auto it = newSubscribers_->find(id);
- if (it != newSubscribers_->end()) {
- newSubscribers_->erase(it);
- return;
- }
- }
- oldSubscribers_->push_back(id);
- } else {
- RWSpinLock::WriteHolder{&observersLock_};
- subscribers_.erase(id);
- }
- }
-
- Subscription<T> makeSubscription(bool indefinite) {
- if (indefinite) {
- return Subscription<T>(nullptr, nextSubscriptionId_++);
- } else {
- if (!unsubscriber_) {
- std::lock_guard<MicroSpinLock> guard(unsubscriberLock_);
- if (!unsubscriber_) {
- unsubscriber_ = std::make_shared<Unsubscriber>(this);
- }
- }
- return Subscription<T>(unsubscriber_, nextSubscriptionId_++);
- }
- }
-
- std::atomic<uint64_t> nextSubscriptionId_;
- RWSpinLock observersLock_;
- folly::ThreadLocalPtr<bool> inCallback_;
-
- typedef folly::small_vector<Observer<T>*, InlineObservers> ObserverList;
- ObserverList observers_;
- folly::ThreadLocalPtr<ObserverList> newObservers_;
-
- typedef std::map<uint64_t, ObserverPtr<T>> SubscriberMap;
- SubscriberMap subscribers_;
- folly::ThreadLocalPtr<SubscriberMap> newSubscribers_;
- folly::ThreadLocalPtr<std::vector<uint64_t>> oldSubscribers_;
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/rx/types.h>
-#include <functional>
-#include <memory>
-#include <stdexcept>
-#include <folly/Memory.h>
-
-namespace folly { namespace wangle {
-
-template <class T> class FunctionObserver;
-
-/// Observer interface. You can subclass it, or you can just use create()
-/// to use std::functions.
-template <class T>
-struct Observer {
- // These are what it means to be an Observer.
- virtual void onNext(const T&) = 0;
- virtual void onError(Error) = 0;
- virtual void onCompleted() = 0;
-
- virtual ~Observer() = default;
-
- /// Create an Observer with std::function callbacks. Handy to make ad-hoc
- /// Observers with lambdas.
- ///
- /// Templated for maximum perfect forwarding flexibility, but ultimately
- /// whatever you pass in has to implicitly become a std::function for the
- /// same signature as onNext(), onError(), and onCompleted() respectively.
- /// (see the FunctionObserver typedefs)
- template <class N, class E, class C>
- static std::unique_ptr<Observer> create(
- N&& onNextFn, E&& onErrorFn, C&& onCompletedFn)
- {
- return folly::make_unique<FunctionObserver<T>>(
- std::forward<N>(onNextFn),
- std::forward<E>(onErrorFn),
- std::forward<C>(onCompletedFn));
- }
-
- /// Create an Observer with only onNext and onError callbacks.
- /// onCompleted will just be a no-op.
- template <class N, class E>
- static std::unique_ptr<Observer> create(N&& onNextFn, E&& onErrorFn) {
- return folly::make_unique<FunctionObserver<T>>(
- std::forward<N>(onNextFn),
- std::forward<E>(onErrorFn),
- nullptr);
- }
-
- /// Create an Observer with only an onNext callback.
- /// onError and onCompleted will just be no-ops.
- template <class N>
- static std::unique_ptr<Observer> create(N&& onNextFn) {
- return folly::make_unique<FunctionObserver<T>>(
- std::forward<N>(onNextFn),
- nullptr,
- nullptr);
- }
-};
-
-/// An observer that uses std::function callbacks. You don't really want to
-/// make one of these directly - instead use the Observer::create() methods.
-template <class T>
-struct FunctionObserver : public Observer<T> {
- typedef std::function<void(const T&)> OnNext;
- typedef std::function<void(Error)> OnError;
- typedef std::function<void()> OnCompleted;
-
- /// We don't need any fancy overloads of this constructor because that's
- /// what Observer::create() is for.
- template <class N = OnNext, class E = OnError, class C = OnCompleted>
- FunctionObserver(N&& n, E&& e, C&& c)
- : onNext_(std::forward<N>(n)),
- onError_(std::forward<E>(e)),
- onCompleted_(std::forward<C>(c))
- {}
-
- void onNext(const T& val) override {
- if (onNext_) onNext_(val);
- }
-
- void onError(Error e) override {
- if (onError_) onError_(e);
- }
-
- void onCompleted() override {
- if (onCompleted_) onCompleted_();
- }
-
- protected:
- OnNext onNext_;
- OnError onError_;
- OnCompleted onCompleted_;
-};
-
-}}
+++ /dev/null
-Rx is a pattern for "functional reactive programming" that started at
-Microsoft in C#, and has been reimplemented in various languages, notably
-RxJava for JVM languages.
-
-It is basically the plural of Futures (a la Wangle).
-
-
- singular | plural
- +---------------------------------+-----------------------------------
- sync | Foo getData() | std::vector<Foo> getData()
- async | wangle::Future<Foo> getData() | wangle::Observable<Foo> getData()
-
-
-For more on Rx, I recommend these resources:
-
-Netflix blog post (RxJava): http://techblog.netflix.com/2013/02/rxjava-netflix-api.html
-Introduction to Rx eBook (C#): http://www.introtorx.com/content/v1.0.10621.0/01_WhyRx.html
-The RxJava wiki: https://github.com/Netflix/RxJava/wiki
-Netflix QCon presentation: http://www.infoq.com/presentations/netflix-functional-rx
-https://rx.codeplex.com/
-
-There are open source C++ implementations, I haven't looked at them. They
-might be the best way to go rather than writing it NIH-style. I mostly did it
-as an exercise, to think through how closely we might want to integrate
-something like this with Wangle, and to get a feel for how it works in C++.
-
-I haven't even tried to support move-only data in this version. I'm on the
-fence about the usage of shared_ptr. Subject is underdeveloped. A whole rich
-set of operations is obviously missing. I haven't decided how to handle
-subscriptions (and therefore cancellation), but I'm pretty sure C#'s
-"Disposable" is thoroughly un-C++ (opposite of RAII). So for now subscribe
-returns nothing at all and you can't cancel anything ever. The whole thing is
-probably riddled with lifetime corner case bugs that will come out like a
-swarm of angry bees as soon as someone tries an infinite sequence, or tries to
-partially observe a long sequence. I'm pretty sure subscribeOn has a bug that
-I haven't tracked down yet.
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/rx/Observable.h>
-#include <folly/experimental/wangle/rx/Observer.h>
-
-namespace folly { namespace wangle {
-
-/// Subject interface. A Subject is both an Observable and an Observer. There
-/// is a default implementation of the Observer methods that just forwards the
-/// observed events to the Subject's observers.
-template <class T>
-struct Subject : public Observable<T>, public Observer<T> {
- void onNext(const T& val) override {
- this->forEachObserver([&](Observer<T>* o){
- o->onNext(val);
- });
- }
- void onError(Error e) override {
- this->forEachObserver([&](Observer<T>* o){
- o->onError(e);
- });
- }
- void onCompleted() override {
- this->forEachObserver([](Observer<T>* o){
- o->onCompleted();
- });
- }
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/experimental/wangle/rx/Observable.h>
-
-namespace folly { namespace wangle {
-
-template <class T>
-class Subscription {
- public:
- Subscription() {}
-
- Subscription(const Subscription&) = delete;
-
- Subscription(Subscription&& other) noexcept {
- *this = std::move(other);
- }
-
- Subscription& operator=(Subscription&& other) noexcept {
- unsubscribe();
- unsubscriber_ = std::move(other.unsubscriber_);
- id_ = other.id_;
- other.unsubscriber_ = nullptr;
- other.id_ = 0;
- return *this;
- }
-
- ~Subscription() {
- unsubscribe();
- }
-
- private:
- typedef typename Observable<T>::Unsubscriber Unsubscriber;
-
- Subscription(std::shared_ptr<Unsubscriber> unsubscriber, uint64_t id)
- : unsubscriber_(std::move(unsubscriber)), id_(id) {
- CHECK(id_ > 0);
- }
-
- void unsubscribe() {
- if (unsubscriber_ && id_ > 0) {
- unsubscriber_->unsubscribe(id_);
- id_ = 0;
- unsubscriber_ = nullptr;
- }
- }
-
- std::shared_ptr<Unsubscriber> unsubscriber_;
- uint64_t id_{0};
-
- friend class Observable<T>;
-};
-
-}}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/Benchmark.h>
-#include <folly/experimental/wangle/rx/Observer.h>
-#include <folly/experimental/wangle/rx/Subject.h>
-#include <gflags/gflags.h>
-
-using namespace folly::wangle;
-using folly::BenchmarkSuspender;
-
-static std::unique_ptr<Observer<int>> makeObserver() {
- return Observer<int>::create([&] (int x) {});
-}
-
-void subscribeImpl(uint iters, int N, bool countUnsubscribe) {
- for (uint iter = 0; iter < iters; iter++) {
- BenchmarkSuspender bs;
- Subject<int> subject;
- std::vector<std::unique_ptr<Observer<int>>> observers;
- std::vector<Subscription<int>> subscriptions;
- subscriptions.reserve(N);
- for (int i = 0; i < N; i++) {
- observers.push_back(makeObserver());
- }
- bs.dismiss();
- for (int i = 0; i < N; i++) {
- subscriptions.push_back(subject.subscribe(std::move(observers[i])));
- }
- if (countUnsubscribe) {
- subscriptions.clear();
- }
- bs.rehire();
- }
-}
-
-void subscribeAndUnsubscribe(uint iters, int N) {
- subscribeImpl(iters, N, true);
-}
-
-void subscribe(uint iters, int N) {
- subscribeImpl(iters, N, false);
-}
-
-void observe(uint iters, int N) {
- for (uint iter = 0; iter < iters; iter++) {
- BenchmarkSuspender bs;
- Subject<int> subject;
- std::vector<std::unique_ptr<Observer<int>>> observers;
- for (int i = 0; i < N; i++) {
- observers.push_back(makeObserver());
- }
- bs.dismiss();
- for (int i = 0; i < N; i++) {
- subject.observe(std::move(observers[i]));
- }
- bs.rehire();
- }
-}
-
-void inlineObserve(uint iters, int N) {
- for (uint iter = 0; iter < iters; iter++) {
- BenchmarkSuspender bs;
- Subject<int> subject;
- std::vector<Observer<int>*> observers;
- for (int i = 0; i < N; i++) {
- observers.push_back(makeObserver().release());
- }
- bs.dismiss();
- for (int i = 0; i < N; i++) {
- subject.observe(observers[i]);
- }
- bs.rehire();
- for (int i = 0; i < N; i++) {
- delete observers[i];
- }
- }
-}
-
-void notifySubscribers(uint iters, int N) {
- for (uint iter = 0; iter < iters; iter++) {
- BenchmarkSuspender bs;
- Subject<int> subject;
- std::vector<std::unique_ptr<Observer<int>>> observers;
- std::vector<Subscription<int>> subscriptions;
- subscriptions.reserve(N);
- for (int i = 0; i < N; i++) {
- observers.push_back(makeObserver());
- }
- for (int i = 0; i < N; i++) {
- subscriptions.push_back(subject.subscribe(std::move(observers[i])));
- }
- bs.dismiss();
- subject.onNext(42);
- bs.rehire();
- }
-}
-
-void notifyInlineObservers(uint iters, int N) {
- for (uint iter = 0; iter < iters; iter++) {
- BenchmarkSuspender bs;
- Subject<int> subject;
- std::vector<Observer<int>*> observers;
- for (int i = 0; i < N; i++) {
- observers.push_back(makeObserver().release());
- }
- for (int i = 0; i < N; i++) {
- subject.observe(observers[i]);
- }
- bs.dismiss();
- subject.onNext(42);
- bs.rehire();
- }
-}
-
-BENCHMARK_PARAM(subscribeAndUnsubscribe, 1);
-BENCHMARK_RELATIVE_PARAM(subscribe, 1);
-BENCHMARK_RELATIVE_PARAM(observe, 1);
-BENCHMARK_RELATIVE_PARAM(inlineObserve, 1);
-
-BENCHMARK_DRAW_LINE();
-
-BENCHMARK_PARAM(subscribeAndUnsubscribe, 1000);
-BENCHMARK_RELATIVE_PARAM(subscribe, 1000);
-BENCHMARK_RELATIVE_PARAM(observe, 1000);
-BENCHMARK_RELATIVE_PARAM(inlineObserve, 1000);
-
-BENCHMARK_DRAW_LINE();
-
-BENCHMARK_PARAM(notifySubscribers, 1);
-BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1);
-
-BENCHMARK_DRAW_LINE();
-
-BENCHMARK_PARAM(notifySubscribers, 1000);
-BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1000);
-
-int main(int argc, char** argv) {
- gflags::ParseCommandLineFlags(&argc, &argv, true);
- folly::runBenchmarks();
- return 0;
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <folly/experimental/wangle/rx/Observer.h>
-#include <folly/experimental/wangle/rx/Subject.h>
-#include <gtest/gtest.h>
-
-using namespace folly::wangle;
-
-static std::unique_ptr<Observer<int>> incrementer(int& counter) {
- return Observer<int>::create([&] (int x) {
- counter++;
- });
-}
-
-TEST(RxTest, Observe) {
- Subject<int> subject;
- auto count = 0;
- subject.observe(incrementer(count));
- subject.onNext(1);
- EXPECT_EQ(1, count);
-}
-
-TEST(RxTest, ObserveInline) {
- Subject<int> subject;
- auto count = 0;
- auto o = incrementer(count).release();
- subject.observe(o);
- subject.onNext(1);
- EXPECT_EQ(1, count);
- delete o;
-}
-
-TEST(RxTest, Subscription) {
- Subject<int> subject;
- auto count = 0;
- {
- auto s = subject.subscribe(incrementer(count));
- subject.onNext(1);
- }
- // The subscription has gone out of scope so no one should get this.
- subject.onNext(2);
- EXPECT_EQ(1, count);
-}
-
-TEST(RxTest, SubscriptionMove) {
- Subject<int> subject;
- auto count = 0;
- auto s = subject.subscribe(incrementer(count));
- auto s2 = subject.subscribe(incrementer(count));
- s2 = std::move(s);
- subject.onNext(1);
- Subscription<int> s3(std::move(s2));
- subject.onNext(2);
- EXPECT_EQ(2, count);
-}
-
-TEST(RxTest, SubscriptionOutlivesSubject) {
- Subscription<int> s;
- {
- Subject<int> subject;
- s = subject.subscribe(Observer<int>::create([](int){}));
- }
- // Don't explode when s is destroyed
-}
-
-TEST(RxTest, SubscribeDuringCallback) {
- // A subscriber who was subscribed in the course of a callback should get
- // subsequent updates but not the current update.
- Subject<int> subject;
- int outerCount = 0, innerCount = 0;
- Subscription<int> s1, s2;
- s1 = subject.subscribe(Observer<int>::create([&] (int x) {
- outerCount++;
- s2 = subject.subscribe(incrementer(innerCount));
- }));
- subject.onNext(42);
- subject.onNext(0xDEADBEEF);
- EXPECT_EQ(2, outerCount);
- EXPECT_EQ(1, innerCount);
-}
-
-TEST(RxTest, ObserveDuringCallback) {
- Subject<int> subject;
- int outerCount = 0, innerCount = 0;
- subject.observe(Observer<int>::create([&] (int x) {
- outerCount++;
- subject.observe(incrementer(innerCount));
- }));
- subject.onNext(42);
- subject.onNext(0xDEADBEEF);
- EXPECT_EQ(2, outerCount);
- EXPECT_EQ(1, innerCount);
-}
-
-TEST(RxTest, ObserveInlineDuringCallback) {
- Subject<int> subject;
- int outerCount = 0, innerCount = 0;
- auto innerO = incrementer(innerCount).release();
- auto outerO = Observer<int>::create([&] (int x) {
- outerCount++;
- subject.observe(innerO);
- }).release();
- subject.observe(outerO);
- subject.onNext(42);
- subject.onNext(0xDEADBEEF);
- EXPECT_EQ(2, outerCount);
- EXPECT_EQ(1, innerCount);
- delete innerO;
- delete outerO;
-}
-
-TEST(RxTest, UnsubscribeDuringCallback) {
- // A subscriber who was unsubscribed in the course of a callback should get
- // the current update but not subsequent ones
- Subject<int> subject;
- int count1 = 0, count2 = 0;
- auto s1 = subject.subscribe(incrementer(count1));
- auto s2 = subject.subscribe(Observer<int>::create([&] (int x) {
- count2++;
- s1.~Subscription();
- }));
- subject.onNext(1);
- subject.onNext(2);
- EXPECT_EQ(1, count1);
- EXPECT_EQ(2, count2);
-}
-
-TEST(RxTest, SubscribeUnsubscribeDuringCallback) {
- // A subscriber who was subscribed and unsubscribed in the course of a
- // callback should not get any updates
- Subject<int> subject;
- int outerCount = 0, innerCount = 0;
- auto s2 = subject.subscribe(Observer<int>::create([&] (int x) {
- outerCount++;
- auto s2 = subject.subscribe(incrementer(innerCount));
- }));
- subject.onNext(1);
- subject.onNext(2);
- EXPECT_EQ(2, outerCount);
- EXPECT_EQ(0, innerCount);
-}
-
-// Move only type
-typedef std::unique_ptr<int> MO;
-static MO makeMO() { return folly::make_unique<int>(1); }
-template <typename T>
-static ObserverPtr<T> makeMOObserver() {
- return Observer<T>::create([](const T& mo) {
- EXPECT_EQ(1, *mo);
- });
-}
-
-TEST(RxTest, MoveOnlyRvalue) {
- Subject<MO> subject;
- auto s1 = subject.subscribe(makeMOObserver<MO>());
- auto s2 = subject.subscribe(makeMOObserver<MO>());
- auto mo = makeMO();
- // Can't bind lvalues to rvalue references
- // subject.onNext(mo);
- subject.onNext(std::move(mo));
- subject.onNext(makeMO());
-}
-
-// Copy only type
-struct CO {
- CO() = default;
- CO(const CO&) = default;
- CO(CO&&) = delete;
-};
-
-template <typename T>
-static ObserverPtr<T> makeCOObserver() {
- return Observer<T>::create([](const T& mo) {});
-}
-
-TEST(RxTest, CopyOnly) {
- Subject<CO> subject;
- auto s1 = subject.subscribe(makeCOObserver<CO>());
- CO co;
- subject.onNext(co);
-}
+++ /dev/null
-/*
- * Copyright 2014 Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <folly/ExceptionWrapper.h>
-#include <folly/Executor.h>
-
-namespace folly { namespace wangle {
- typedef folly::exception_wrapper Error;
- // The Executor is basically an rx Scheduler (by design). So just
- // alias it.
- typedef std::shared_ptr<folly::Executor> SchedulerPtr;
-
- template <class T, size_t InlineObservers = 3> struct Observable;
- template <class T> struct Observer;
- template <class T> struct Subject;
-
- template <class T> using ObservablePtr = std::shared_ptr<Observable<T>>;
- template <class T> using ObserverPtr = std::shared_ptr<Observer<T>>;
- template <class T> using SubjectPtr = std::shared_ptr<Subject<T>>;
-}}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-namespace folly {
-
-class ClientHelloExtStats {
- public:
- virtual ~ClientHelloExtStats() noexcept {}
-
- // client hello
- virtual void recordAbsentHostname() noexcept = 0;
- virtual void recordMatch() noexcept = 0;
- virtual void recordNotMatch() noexcept = 0;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <openssl/dh.h>
-
-// The following was auto-generated by
-// openssl dhparam -C 2048
-DH *get_dh2048()
- {
- static unsigned char dh2048_p[]={
- 0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95,
- 0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79,
- 0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02,
- 0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC,
- 0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8,
- 0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6,
- 0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8,
- 0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF,
- 0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E,
- 0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09,
- 0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D,
- 0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B,
- 0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A,
- 0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8,
- 0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E,
- 0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4,
- 0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9,
- 0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E,
- 0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6,
- 0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD,
- 0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB,
- 0x8C,0xA1,0x76,0xDB,
- };
- static unsigned char dh2048_g[]={
- 0x02,
- };
- DH *dh;
-
- if ((dh=DH_new()) == nullptr) return(nullptr);
- dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr);
- dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr);
- if ((dh->p == nullptr) || (dh->g == nullptr))
- { DH_free(dh); return(nullptr); }
- return(dh);
- }
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/ssl/PasswordInFile.h>
-
-#include <folly/FileUtil.h>
-
-using namespace std;
-
-namespace folly {
-
-PasswordInFile::PasswordInFile(const string& file)
- : fileName_(file) {
- folly::readFile(file.c_str(), password_);
- auto p = password_.find('\0');
- if (p != std::string::npos) {
- password_.erase(p);
- }
-}
-
-PasswordInFile::~PasswordInFile() {
- OPENSSL_cleanse((char *)password_.data(), password_.length());
-}
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/io/async/SSLContext.h> // PasswordCollector
-
-namespace folly {
-
-class PasswordInFile: public folly::PasswordCollector {
- public:
- explicit PasswordInFile(const std::string& file);
- ~PasswordInFile();
-
- void getPassword(std::string& password, int size) override {
- password = password_;
- }
-
- const char* getPasswordStr() const {
- return password_.c_str();
- }
-
- std::string describe() const override {
- return fileName_;
- }
-
- protected:
- std::string fileName_;
- std::string password_;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <chrono>
-#include <cstdint>
-
-namespace folly {
-
-struct SSLCacheOptions {
- std::chrono::seconds sslCacheTimeout;
- uint64_t maxSSLCacheSize;
- uint64_t sslCacheFlushSize;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/io/async/AsyncSSLSocket.h>
-
-namespace folly {
-
-class SSLSessionCacheManager;
-
-/**
- * Interface to be implemented by providers of external session caches
- */
-class SSLCacheProvider {
-public:
- /**
- * Context saved during an external cache request that is used to
- * resume the waiting client.
- */
- typedef struct {
- std::string sessionId;
- SSL_SESSION* session;
- SSLSessionCacheManager* manager;
- AsyncSSLSocket* sslSocket;
- std::unique_ptr<
- folly::DelayedDestruction::DestructorGuard> guard;
- } CacheContext;
-
- virtual ~SSLCacheProvider() {}
-
- /**
- * Store a session in the external cache.
- * @param sessionId Identifier that can be used later to fetch the
- * session with getAsync()
- * @param value Serialized session to store
- * @param expiration Relative expiration time: seconds from now
- * @return true if the storing of the session is initiated successfully
- * (though not necessarily completed; the completion may
- * happen either before or after this method returns), or
- * false if the storing cannot be initiated due to an error.
- */
- virtual bool setAsync(const std::string& sessionId,
- const std::string& value,
- std::chrono::seconds expiration) = 0;
-
- /**
- * Retrieve a session from the external cache. When done, call
- * the cache manager's onGetSuccess() or onGetFailure() callback.
- * @param sessionId Session ID to fetch
- * @param context Data to pass back to the SSLSessionCacheManager
- * in the completion callback
- * @return true if the lookup of the session is initiated successfully
- * (though not necessarily completed; the completion may
- * happen either before or after this method returns), or
- * false if the lookup cannot be initiated due to an error.
- */
- virtual bool getAsync(const std::string& sessionId,
- CacheContext* context) = 0;
-
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <string>
-#include <folly/io/async/SSLContext.h>
-#include <vector>
-
-/**
- * SSLContextConfig helps to describe the configs/options for
- * a SSL_CTX. For example:
- *
- * 1. Filename of X509, private key and its password.
- * 2. ciphers list
- * 3. NPN list
- * 4. Is session cache enabled?
- * 5. Is it the default X509 in SNI operation?
- * 6. .... and a few more
- */
-namespace folly {
-
-struct SSLContextConfig {
- SSLContextConfig() {}
- ~SSLContextConfig() {}
-
- struct CertificateInfo {
- std::string certPath;
- std::string keyPath;
- std::string passwordPath;
- };
-
- /**
- * Helpers to set/add a certificate
- */
- void setCertificate(const std::string& certPath,
- const std::string& keyPath,
- const std::string& passwordPath) {
- certificates.clear();
- addCertificate(certPath, keyPath, passwordPath);
- }
-
- void addCertificate(const std::string& certPath,
- const std::string& keyPath,
- const std::string& passwordPath) {
- certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath});
- }
-
- /**
- * Set the optional list of protocols to advertise via TLS
- * Next Protocol Negotiation. An empty list means NPN is not enabled.
- */
- void setNextProtocols(const std::list<std::string>& inNextProtocols) {
- nextProtocols.clear();
- nextProtocols.push_back({1, inNextProtocols});
- }
-
- typedef std::function<bool(char const* server_name)> SNINoMatchFn;
-
- std::vector<CertificateInfo> certificates;
- folly::SSLContext::SSLVersion sslVersion{
- folly::SSLContext::TLSv1};
- bool sessionCacheEnabled{true};
- bool sessionTicketEnabled{true};
- bool clientHelloParsingEnabled{false};
- std::string sslCiphers{
- "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:"
- "ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:"
- "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:"
- "AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:"
- "ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:"
- "ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"};
- std::string eccCurveName;
- // Ciphers to negotiate if TLS version >= 1.1
- std::string tls11Ciphers{""};
- // Weighted lists of NPN strings to advertise
- std::list<folly::SSLContext::NextProtocolsItem>
- nextProtocols;
- bool isLocalPrivateKey{true};
- // Should this SSLContextConfig be the default for SNI purposes
- bool isDefault{false};
- // Callback function to invoke when there are no matching certificates
- // (will only be invoked once)
- SNINoMatchFn sniNoMatchFn;
- // File containing trusted CA's to validate client certificates
- std::string clientCAFile;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/ssl/SSLContextManager.h>
-
-#include <folly/experimental/wangle/ssl/ClientHelloExtStats.h>
-#include <folly/experimental/wangle/ssl/DHParam.h>
-#include <folly/experimental/wangle/ssl/PasswordInFile.h>
-#include <folly/experimental/wangle/ssl/SSLCacheOptions.h>
-#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
-#include <folly/experimental/wangle/ssl/SSLUtil.h>
-#include <folly/experimental/wangle/ssl/TLSTicketKeyManager.h>
-#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
-
-#include <folly/Conv.h>
-#include <folly/ScopeGuard.h>
-#include <folly/String.h>
-#include <functional>
-#include <openssl/asn1.h>
-#include <openssl/ssl.h>
-#include <string>
-#include <folly/io/async/EventBase.h>
-
-#define OPENSSL_MISSING_FEATURE(name) \
-do { \
- throw std::runtime_error("missing " #name " support in openssl"); \
-} while(0)
-
-
-using std::string;
-using std::shared_ptr;
-
-/**
- * SSLContextManager helps to create and manage all SSL_CTX,
- * SSLSessionCacheManager and TLSTicketManager for a listening
- * VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)).
- *
- * Other responsibilities:
- * 1. It also handles the SSL_CTX selection after getting the tlsext_hostname
- * in the client hello message.
- *
- * Usage:
- * 1. Each listening VIP:PORT serving SSL should have one SSLContextManager.
- * It maps to Acceptor in the wangle vocabulary.
- *
- * 2. Create a SSLContextConfig object (e.g. by parsing the JSON config).
- *
- * 3. Call SSLContextManager::addSSLContextConfig() which will
- * then create and configure the SSL_CTX
- *
- * Note: Each Acceptor, with SSL support, should have one SSLContextManager to
- * manage all SSL_CTX for the VIP:PORT.
- */
-
-namespace folly {
-
-namespace {
-
-X509* getX509(SSL_CTX* ctx) {
- SSL* ssl = SSL_new(ctx);
- SSL_set_connect_state(ssl);
- X509* x509 = SSL_get_certificate(ssl);
- CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
- SSL_free(ssl);
- return x509;
-}
-
-void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) {
-#if OPENSSL_VERSION_NUMBER >= 0x0090800fL
-#ifndef OPENSSL_NO_ECDH
- EC_KEY* ecdh = nullptr;
- int nid;
-
- /*
- * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
- * from RFC 4492 section 5.1.1, or explicitly described curves over
- * binary fields. OpenSSL only supports the "named curves", which provide
- * maximum interoperability.
- */
-
- nid = OBJ_sn2nid(curveName.c_str());
- if (nid == 0) {
- LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
- return;
- }
- ecdh = EC_KEY_new_by_curve_name(nid);
- if (ecdh == nullptr) {
- LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
- return;
- }
-
- SSL_CTX_set_tmp_ecdh(ctx, ecdh);
- EC_KEY_free(ecdh);
-#endif
-#endif
-}
-
-// Helper to create TLSTicketKeyManger and aware of the needed openssl
-// version/feature.
-std::unique_ptr<TLSTicketKeyManager> createTicketManagerHelper(
- std::shared_ptr<folly::SSLContext> ctx,
- const TLSTicketKeySeeds* ticketSeeds,
- const SSLContextConfig& ctxConfig,
- SSLStats* stats) {
-
- std::unique_ptr<TLSTicketKeyManager> ticketManager;
-#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
- if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
- ticketManager = folly::make_unique<TLSTicketKeyManager>(ctx.get(), stats);
- ticketManager->setTLSTicketKeySeeds(
- ticketSeeds->oldSeeds,
- ticketSeeds->currentSeeds,
- ticketSeeds->newSeeds);
- } else {
- ctx->setOptions(SSL_OP_NO_TICKET);
- }
-#else
- if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
- OPENSSL_MISSING_FEATURE(TLSTicket);
- }
-#endif
- return ticketManager;
-}
-
-std::string flattenList(const std::list<std::string>& list) {
- std::string s;
- bool first = true;
- for (auto& item : list) {
- if (first) {
- first = false;
- } else {
- s.append(", ");
- }
- s.append(item);
- }
- return s;
-}
-
-}
-
-SSLContextManager::~SSLContextManager() {}
-
-SSLContextManager::SSLContextManager(
- EventBase* eventBase,
- const std::string& vipName,
- bool strict,
- SSLStats* stats) :
- stats_(stats),
- eventBase_(eventBase),
- strict_(strict) {
-}
-
-void SSLContextManager::addSSLContextConfig(
- const SSLContextConfig& ctxConfig,
- const SSLCacheOptions& cacheOptions,
- const TLSTicketKeySeeds* ticketSeeds,
- const folly::SocketAddress& vipAddress,
- const std::shared_ptr<SSLCacheProvider>& externalCache) {
-
- unsigned numCerts = 0;
- std::string commonName;
- std::string lastCertPath;
- std::unique_ptr<std::list<std::string>> subjectAltName;
- auto sslCtx = std::make_shared<SSLContext>(ctxConfig.sslVersion);
- for (const auto& cert : ctxConfig.certificates) {
- try {
- sslCtx->loadCertificate(cert.certPath.c_str());
- } catch (const std::exception& ex) {
- // The exception isn't very useful without the certificate path name,
- // so throw a new exception that includes the path to the certificate.
- string msg = folly::to<string>("error loading SSL certificate ",
- cert.certPath, ": ",
- folly::exceptionStr(ex));
- LOG(ERROR) << msg;
- throw std::runtime_error(msg);
- }
-
- // Verify that the Common Name and (if present) Subject Alternative Names
- // are the same for all the certs specified for the SSL context.
- numCerts++;
- X509* x509 = getX509(sslCtx->getSSLCtx());
- auto guard = folly::makeGuard([x509] { X509_free(x509); });
- auto cn = SSLUtil::getCommonName(x509);
- if (!cn) {
- throw std::runtime_error(folly::to<string>("Cannot get CN for X509 ",
- cert.certPath));
- }
- auto altName = SSLUtil::getSubjectAltName(x509);
- VLOG(2) << "cert " << cert.certPath << " CN: " << *cn;
- if (altName) {
- altName->sort();
- VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName);
- } else {
- VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}";
- }
- if (numCerts == 1) {
- commonName = *cn;
- subjectAltName = std::move(altName);
- } else {
- if (commonName != *cn) {
- throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
- " does not have same CN as ",
- lastCertPath));
- }
- if (altName == nullptr) {
- if (subjectAltName != nullptr) {
- throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
- " does not have same SAN as ",
- lastCertPath));
- }
- } else {
- if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) {
- throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
- " does not have same SAN as ",
- lastCertPath));
- }
- }
- }
- lastCertPath = cert.certPath;
-
- // TODO t4438250 - Add ECDSA support to the crypto_ssl offload server
- // so we can avoid storing the ECDSA private key in the
- // address space of the Internet-facing process. For
- // now, if cert name includes "-EC" to denote elliptic
- // curve, we load its private key even if the server as
- // a whole has been configured for async crypto.
- if (ctxConfig.isLocalPrivateKey ||
- (cert.certPath.find("-EC") != std::string::npos)) {
- // The private key lives in the same process
-
- // This needs to be called before loadPrivateKey().
- if (!cert.passwordPath.empty()) {
- auto sslPassword = std::make_shared<PasswordInFile>(cert.passwordPath);
- sslCtx->passwordCollector(sslPassword);
- }
-
- try {
- sslCtx->loadPrivateKey(cert.keyPath.c_str());
- } catch (const std::exception& ex) {
- // Throw an error that includes the key path, so the user can tell
- // which key had a problem.
- string msg = folly::to<string>("error loading private SSL key ",
- cert.keyPath, ": ",
- folly::exceptionStr(ex));
- LOG(ERROR) << msg;
- throw std::runtime_error(msg);
- }
- }
- }
- if (!ctxConfig.isLocalPrivateKey) {
- enableAsyncCrypto(sslCtx);
- }
-
- // Let the server pick the highest performing cipher from among the client's
- // choices.
- //
- // Let's use a unique private key for all DH key exchanges.
- //
- // Because some old implementations choke on empty fragments, most SSL
- // applications disable them (it's part of SSL_OP_ALL). This
- // will improve performance and decrease write buffer fragmentation.
- sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE |
- SSL_OP_SINGLE_DH_USE |
- SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
-
- // Configure SSL ciphers list
- if (!ctxConfig.tls11Ciphers.empty()) {
- // FIXME: create a dummy SSL_CTX for cipher testing purpose? It can
- // remove the ordering dependency
-
- // Test to see if the specified TLS1.1 ciphers are valid. Note that
- // these will be overwritten by the ciphers() call below.
- sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers);
- }
-
- // Important that we do this *after* checking the TLS1.1 ciphers above,
- // since we test their validity by actually setting them.
- sslCtx->ciphers(ctxConfig.sslCiphers);
-
- // Use a fix DH param
- DH* dh = get_dh2048();
- SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh);
- DH_free(dh);
-
- const string& curve = ctxConfig.eccCurveName;
- if (!curve.empty()) {
- set_key_from_curve(sslCtx->getSSLCtx(), curve);
- }
-
- if (!ctxConfig.clientCAFile.empty()) {
- try {
- sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT);
- sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str());
- sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str());
- } catch (const std::exception& ex) {
- string msg = folly::to<string>("error loading client CA",
- ctxConfig.clientCAFile, ": ",
- folly::exceptionStr(ex));
- LOG(ERROR) << msg;
- throw std::runtime_error(msg);
- }
- }
-
- // - start - SSL session cache config
- // the internal cache never does what we want (per-thread-per-vip).
- // Disable it. SSLSessionCacheManager will set it appropriately.
- SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF);
- SSL_CTX_set_timeout(sslCtx->getSSLCtx(),
- cacheOptions.sslCacheTimeout.count());
- std::unique_ptr<SSLSessionCacheManager> sessionCacheManager;
- if (ctxConfig.sessionCacheEnabled &&
- cacheOptions.maxSSLCacheSize > 0 &&
- cacheOptions.sslCacheFlushSize > 0) {
- sessionCacheManager =
- folly::make_unique<SSLSessionCacheManager>(
- cacheOptions.maxSSLCacheSize,
- cacheOptions.sslCacheFlushSize,
- sslCtx.get(),
- vipAddress,
- commonName,
- eventBase_,
- stats_,
- externalCache);
- }
- // - end - SSL session cache config
-
- std::unique_ptr<TLSTicketKeyManager> ticketManager =
- createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_);
-
- // finalize sslCtx setup by the individual features supported by openssl
- ctxSetupByOpensslFeature(sslCtx, ctxConfig);
-
- try {
- insert(sslCtx,
- std::move(sessionCacheManager),
- std::move(ticketManager),
- ctxConfig.isDefault);
- } catch (const std::exception& ex) {
- string msg = folly::to<string>("Error adding certificate : ",
- folly::exceptionStr(ex));
- LOG(ERROR) << msg;
- throw std::runtime_error(msg);
- }
-
-}
-
-#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
-SSLContext::ServerNameCallbackResult
-SSLContextManager::serverNameCallback(SSL* ssl) {
- shared_ptr<SSLContext> ctx;
-
- const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
- if (!sn) {
- VLOG(6) << "Server Name (tlsext_hostname) is missing";
- if (clientHelloTLSExtStats_) {
- clientHelloTLSExtStats_->recordAbsentHostname();
- }
- return SSLContext::SERVER_NAME_NOT_FOUND;
- }
- size_t snLen = strlen(sn);
- VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' ";
-
- // FIXME: This code breaks the abstraction. Suggestion?
- AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
- CHECK(sslSocket);
-
- DNString dnstr(sn, snLen);
-
- uint32_t count = 0;
- do {
- // Try exact match first
- ctx = getSSLCtx(dnstr);
- if (ctx) {
- sslSocket->switchServerSSLContext(ctx);
- if (clientHelloTLSExtStats_) {
- clientHelloTLSExtStats_->recordMatch();
- }
- return SSLContext::SERVER_NAME_FOUND;
- }
-
- ctx = getSSLCtxBySuffix(dnstr);
- if (ctx) {
- sslSocket->switchServerSSLContext(ctx);
- if (clientHelloTLSExtStats_) {
- clientHelloTLSExtStats_->recordMatch();
- }
- return SSLContext::SERVER_NAME_FOUND;
- }
-
- // Give the noMatchFn one chance to add the correct cert
- }
- while (count++ == 0 && noMatchFn_ && noMatchFn_(sn));
-
- VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn);
-
- if (clientHelloTLSExtStats_) {
- clientHelloTLSExtStats_->recordNotMatch();
- }
- return SSLContext::SERVER_NAME_NOT_FOUND;
-}
-#endif
-
-// Consolidate all SSL_CTX setup which depends on openssl version/feature
-void
-SSLContextManager::ctxSetupByOpensslFeature(
- shared_ptr<folly::SSLContext> sslCtx,
- const SSLContextConfig& ctxConfig) {
- // Disable compression - profiling shows this to be very expensive in
- // terms of CPU and memory consumption.
- //
-#ifdef SSL_OP_NO_COMPRESSION
- sslCtx->setOptions(SSL_OP_NO_COMPRESSION);
-#endif
-
- // Enable early release of SSL buffers to reduce the memory footprint
-#ifdef SSL_MODE_RELEASE_BUFFERS
- sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS;
-#endif
-#ifdef SSL_MODE_EARLY_RELEASE_BBIO
- sslCtx->getSSLCtx()->mode |= SSL_MODE_EARLY_RELEASE_BBIO;
-#endif
-
- // This number should (probably) correspond to HTTPSession::kMaxReadSize
- // For now, this number must also be large enough to accommodate our
- // largest certificate, because some older clients (IE6/7) require the
- // cert to be in a single fragment.
-#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT
- SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000);
-#endif
-
- // Specify cipher(s) to be used for TLS1.1 client
- if (!ctxConfig.tls11Ciphers.empty()) {
-#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
- // Specified TLS1.1 ciphers are valid
- sslCtx->addClientHelloCallback(
- std::bind(
- &SSLContext::switchCiphersIfTLS11,
- sslCtx.get(),
- std::placeholders::_1,
- ctxConfig.tls11Ciphers
- )
- );
-#else
- OPENSSL_MISSING_FEATURE(SNI);
-#endif
- }
-
- // NPN (Next Protocol Negotiation)
- if (!ctxConfig.nextProtocols.empty()) {
-#ifdef OPENSSL_NPN_NEGOTIATED
- sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols);
-#else
- OPENSSL_MISSING_FEATURE(NPN);
-#endif
- }
-
- // SNI
-#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
- noMatchFn_ = ctxConfig.sniNoMatchFn;
- if (ctxConfig.isDefault) {
- if (defaultCtx_) {
- throw std::runtime_error(">1 X509 is set as default");
- }
-
- defaultCtx_ = sslCtx;
- defaultCtx_->setServerNameCallback(
- std::bind(&SSLContextManager::serverNameCallback, this,
- std::placeholders::_1));
- }
-#else
- if (ctxs_.size() > 1) {
- OPENSSL_MISSING_FEATURE(SNI);
- }
-#endif
-}
-
-void
-SSLContextManager::insert(shared_ptr<SSLContext> sslCtx,
- std::unique_ptr<SSLSessionCacheManager> smanager,
- std::unique_ptr<TLSTicketKeyManager> tmanager,
- bool defaultFallback) {
- X509* x509 = getX509(sslCtx->getSSLCtx());
- auto guard = folly::makeGuard([x509] { X509_free(x509); });
- auto cn = SSLUtil::getCommonName(x509);
- if (!cn) {
- throw std::runtime_error("Cannot get CN");
- }
-
- /**
- * Some notes from RFC 2818. Only for future quick references in case of bugs
- *
- * RFC 2818 section 3.1:
- * "......
- * If a subjectAltName extension of type dNSName is present, that MUST
- * be used as the identity. Otherwise, the (most specific) Common Name
- * field in the Subject field of the certificate MUST be used. Although
- * the use of the Common Name is existing practice, it is deprecated and
- * Certification Authorities are encouraged to use the dNSName instead.
- * ......
- * In some cases, the URI is specified as an IP address rather than a
- * hostname. In this case, the iPAddress subjectAltName must be present
- * in the certificate and must exactly match the IP in the URI.
- * ......"
- */
-
- // Not sure if we ever get this kind of X509...
- // If we do, assume '*' is always in the CN and ignore all subject alternative
- // names.
- if (cn->length() == 1 && (*cn)[0] == '*') {
- if (!defaultFallback) {
- throw std::runtime_error("STAR X509 is not the default");
- }
- ctxs_.emplace_back(sslCtx);
- sessionCacheManagers_.emplace_back(std::move(smanager));
- ticketManagers_.emplace_back(std::move(tmanager));
- return;
- }
-
- // Insert by CN
- insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx);
-
- // Insert by subject alternative name(s)
- auto altNames = SSLUtil::getSubjectAltName(x509);
- if (altNames) {
- for (auto& name : *altNames) {
- insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx);
- }
- }
-
- ctxs_.emplace_back(sslCtx);
- sessionCacheManagers_.emplace_back(std::move(smanager));
- ticketManagers_.emplace_back(std::move(tmanager));
-}
-
-void
-SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len,
- shared_ptr<SSLContext> sslCtx) {
- try {
- insertSSLCtxByDomainNameImpl(dn, len, sslCtx);
- } catch (const std::runtime_error& ex) {
- if (strict_) {
- throw ex;
- } else {
- LOG(ERROR) << ex.what() << " DN=" << dn;
- }
- }
-}
-void
-SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len,
- shared_ptr<SSLContext> sslCtx)
-{
- VLOG(4) <<
- folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for "
- "SNI search", dn);
-
- // Only support wildcard domains which are prefixed exactly by "*." .
- // "*" appearing at other locations is not accepted.
-
- if (len > 2 && dn[0] == '*') {
- if (dn[1] == '.') {
- // skip the first '*'
- dn++;
- len--;
- } else {
- throw std::runtime_error(
- "Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" "
- "(only allow character \".\" after \"*\"");
- }
- }
-
- if (len == 1 && *dn == '.') {
- throw std::runtime_error("X509 has only '.' in the CN or subject alternative name "
- "(after removing any preceding '*')");
- }
-
- if (strchr(dn, '*')) {
- throw std::runtime_error("X509 has '*' in the the CN or subject alternative name "
- "(after removing any preceding '*')");
- }
-
- DNString dnstr(dn, len);
- const auto v = dnMap_.find(dnstr);
- if (v == dnMap_.end()) {
- dnMap_.emplace(dnstr, sslCtx);
- } else if (v->second == sslCtx) {
- VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509."
- " Ignore the later name.";
- } else {
- throw std::runtime_error("Duplicate CN or subject alternative name found: \"" +
- std::string(dnstr.c_str()) + "\"");
- }
-}
-
-shared_ptr<SSLContext>
-SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const
-{
- size_t dot;
-
- if ((dot = dnstr.find_first_of(".")) != DNString::npos) {
- DNString suffixDNStr(dnstr, dot);
- const auto v = dnMap_.find(suffixDNStr);
- if (v != dnMap_.end()) {
- VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"",
- dnstr.c_str(), suffixDNStr.c_str());
- return v->second;
- }
- }
-
- VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match",
- dnstr.c_str());
- return shared_ptr<SSLContext>();
-}
-
-shared_ptr<SSLContext>
-SSLContextManager::getSSLCtx(const DNString& dnstr) const
-{
- const auto v = dnMap_.find(dnstr);
- if (v == dnMap_.end()) {
- VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match",
- dnstr.c_str());
- return shared_ptr<SSLContext>();
- } else {
- VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str());
- return v->second;
- }
-}
-
-shared_ptr<SSLContext>
-SSLContextManager::getDefaultSSLCtx() const {
- return defaultCtx_;
-}
-
-void
-SSLContextManager::reloadTLSTicketKeys(
- const std::vector<std::string>& oldSeeds,
- const std::vector<std::string>& currentSeeds,
- const std::vector<std::string>& newSeeds) {
-#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
- for (auto& tmgr: ticketManagers_) {
- tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds);
- }
-#endif
-}
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/SSLContext.h>
-
-#include <glog/logging.h>
-#include <list>
-#include <memory>
-#include <folly/experimental/wangle/ssl/SSLContextConfig.h>
-#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
-#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
-#include <folly/experimental/wangle/acceptor/DomainNameMisc.h>
-#include <vector>
-
-namespace folly {
-
-class SocketAddress;
-class SSLContext;
-class ClientHelloExtStats;
-class SSLCacheOptions;
-class SSLStats;
-class TLSTicketKeyManager;
-class TLSTicketKeySeeds;
-
-class SSLContextManager {
- public:
-
- explicit SSLContextManager(EventBase* eventBase,
- const std::string& vipName, bool strict,
- SSLStats* stats);
- virtual ~SSLContextManager();
-
- /**
- * Add a new X509 to SSLContextManager. The details of a X509
- * is passed as a SSLContextConfig object.
- *
- * @param ctxConfig Details of a X509, its private key, password, etc.
- * @param cacheOptions Options for how to do session caching.
- * @param ticketSeeds If non-null, the initial ticket key seeds to use.
- * @param vipAddress Which VIP are the X509(s) used for? It is only for
- * for user friendly log message
- * @param externalCache Optional external provider for the session cache;
- * may be null
- */
- void addSSLContextConfig(
- const SSLContextConfig& ctxConfig,
- const SSLCacheOptions& cacheOptions,
- const TLSTicketKeySeeds* ticketSeeds,
- const folly::SocketAddress& vipAddress,
- const std::shared_ptr<SSLCacheProvider> &externalCache);
-
- /**
- * Get the default SSL_CTX for a VIP
- */
- std::shared_ptr<SSLContext>
- getDefaultSSLCtx() const;
-
- /**
- * Search by the _one_ level up subdomain
- */
- std::shared_ptr<SSLContext>
- getSSLCtxBySuffix(const DNString& dnstr) const;
-
- /**
- * Search by the full-string domain name
- */
- std::shared_ptr<SSLContext>
- getSSLCtx(const DNString& dnstr) const;
-
- /**
- * Insert a SSLContext by domain name.
- */
- void insertSSLCtxByDomainName(
- const char* dn,
- size_t len,
- std::shared_ptr<SSLContext> sslCtx);
-
- void insertSSLCtxByDomainNameImpl(
- const char* dn,
- size_t len,
- std::shared_ptr<SSLContext> sslCtx);
-
- void reloadTLSTicketKeys(const std::vector<std::string>& oldSeeds,
- const std::vector<std::string>& currentSeeds,
- const std::vector<std::string>& newSeeds);
-
- /**
- * SSLContextManager only collects SNI stats now
- */
-
- void setClientHelloExtStats(ClientHelloExtStats* stats) {
- clientHelloTLSExtStats_ = stats;
- }
-
- protected:
- virtual void enableAsyncCrypto(
- const std::shared_ptr<SSLContext>& sslCtx) {
- LOG(FATAL) << "Unsupported in base SSLContextManager";
- }
- SSLStats* stats_{nullptr};
-
- private:
- SSLContextManager(const SSLContextManager&) = delete;
-
- void ctxSetupByOpensslFeature(
- std::shared_ptr<SSLContext> sslCtx,
- const SSLContextConfig& ctxConfig);
-
- /**
- * Callback function from openssl to find the right X509 to
- * use during SSL handshake
- */
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \
- !defined(OPENSSL_NO_TLSEXT) && \
- defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB)
-# define PROXYGEN_HAVE_SERVERNAMECALLBACK
- SSLContext::ServerNameCallbackResult
- serverNameCallback(SSL* ssl);
-#endif
-
- /**
- * The following functions help to maintain the data structure for
- * domain name matching in SNI. Some notes:
- *
- * 1. It is a best match.
- *
- * 2. It allows wildcard CN and wildcard subject alternative name in a X509.
- * The wildcard name must be _prefixed_ by '*.'. It errors out whenever
- * it sees '*' in any other locations.
- *
- * 3. It uses one std::unordered_map<DomainName, SSL_CTX> object to
- * do this. For wildcard name like "*.facebook.com", ".facebook.com"
- * is used as the key.
- *
- * 4. After getting tlsext_hostname from the client hello message, it
- * will do a full string search first and then try one level up to
- * match any wildcard name (if any) in the X509.
- * [Note, browser also only looks one level up when matching the requesting
- * domain name with the wildcard name in the server X509].
- */
-
- void insert(
- std::shared_ptr<SSLContext> sslCtx,
- std::unique_ptr<SSLSessionCacheManager> cmanager,
- std::unique_ptr<TLSTicketKeyManager> tManager,
- bool defaultFallback);
-
- /**
- * Container to own the SSLContext, SSLSessionCacheManager and
- * TLSTicketKeyManager.
- */
- std::vector<std::shared_ptr<SSLContext>> ctxs_;
- std::vector<std::unique_ptr<SSLSessionCacheManager>>
- sessionCacheManagers_;
- std::vector<std::unique_ptr<TLSTicketKeyManager>> ticketManagers_;
-
- std::shared_ptr<SSLContext> defaultCtx_;
-
- /**
- * Container to store the (DomainName -> SSL_CTX) mapping
- */
- std::unordered_map<
- DNString,
- std::shared_ptr<SSLContext>,
- DNStringHash> dnMap_;
-
- EventBase* eventBase_;
- ClientHelloExtStats* clientHelloTLSExtStats_{nullptr};
- SSLContextConfig::SNINoMatchFn noMatchFn_;
- bool strict_{true};
-};
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
-
-#include <folly/experimental/wangle/ssl/SSLCacheProvider.h>
-#include <folly/experimental/wangle/ssl/SSLStats.h>
-#include <folly/experimental/wangle/ssl/SSLUtil.h>
-
-#include <folly/io/async/EventBase.h>
-
-#ifndef NO_LIB_GFLAGS
-#include <gflags/gflags.h>
-#endif
-
-using std::string;
-using std::shared_ptr;
-
-namespace {
-
-const uint32_t NUM_CACHE_BUCKETS = 16;
-
-// We use the default ID generator which fills the maximum ID length
-// for the protocol. 16 bytes for SSLv2 or 32 for SSLv3+
-const int MIN_SESSION_ID_LENGTH = 16;
-
-}
-
-#ifndef NO_LIB_GFLAGS
-DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache");
-#else
-const bool FLAGS_dcache_unit_test = false;
-#endif
-
-namespace folly {
-
-
-int SSLSessionCacheManager::sExDataIndex_ = -1;
-shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::sCache_;
-std::mutex SSLSessionCacheManager::sCacheLock_;
-
-LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize,
- uint32_t cacheCullSize)
- : sessionCache(maxCacheSize, cacheCullSize) {
- sessionCache.setPruneHook(std::bind(
- &LocalSSLSessionCache::pruneSessionCallback,
- this, std::placeholders::_1,
- std::placeholders::_2));
-}
-
-void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId,
- SSL_SESSION* session) {
- VLOG(4) << "Free SSL session from local cache; id="
- << SSLUtil::hexlify(sessionId);
- SSL_SESSION_free(session);
- ++removedSessions_;
-}
-
-
-// SSLSessionCacheManager implementation
-
-SSLSessionCacheManager::SSLSessionCacheManager(
- uint32_t maxCacheSize,
- uint32_t cacheCullSize,
- SSLContext* ctx,
- const folly::SocketAddress& sockaddr,
- const string& context,
- EventBase* eventBase,
- SSLStats* stats,
- const std::shared_ptr<SSLCacheProvider>& externalCache):
- ctx_(ctx),
- stats_(stats),
- externalCache_(externalCache) {
-
- SSL_CTX* sslCtx = ctx->getSSLCtx();
-
- SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
-
- SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this);
- SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback);
- SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback);
- SSL_CTX_sess_set_remove_cb(sslCtx,
- SSLSessionCacheManager::removeSessionCallback);
- if (!FLAGS_dcache_unit_test && !context.empty()) {
- // Use the passed in context
- SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(),
- std::min((int)context.length(),
- SSL_MAX_SSL_SESSION_ID_LENGTH));
- }
-
- SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL
- | SSL_SESS_CACHE_SERVER);
-
- localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize,
- cacheCullSize);
-
- VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context;
-}
-
-SSLSessionCacheManager::~SSLSessionCacheManager() {
-}
-
-void SSLSessionCacheManager::shutdown() {
- std::lock_guard<std::mutex> g(sCacheLock_);
- sCache_.reset();
-}
-
-shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::getLocalCache(
- uint32_t maxCacheSize,
- uint32_t cacheCullSize) {
-
- std::lock_guard<std::mutex> g(sCacheLock_);
- if (!sCache_) {
- sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS,
- maxCacheSize,
- cacheCullSize));
- }
- return sCache_;
-}
-
-int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
- SSLSessionCacheManager* manager = nullptr;
- SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
- manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
-
- if (manager == nullptr) {
- LOG(FATAL) << "Null SSLSessionCacheManager in callback";
- return -1;
- }
- return manager->newSession(ssl, session);
-}
-
-
-int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) {
- string sessionId((char*)session->session_id, session->session_id_length);
- VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId);
-
- if (stats_) {
- stats_->recordSSLSession(true /* new session */, false, false);
- }
-
- localCache_->storeSession(sessionId, session, stats_);
-
- if (externalCache_) {
- VLOG(4) << "New SSL session: send session to external cache; id=" <<
- SSLUtil::hexlify(sessionId);
- storeCacheRecord(sessionId, session);
- }
-
- return 1;
-}
-
-void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx,
- SSL_SESSION* session) {
- SSLSessionCacheManager* manager = nullptr;
- manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
-
- if (manager == nullptr) {
- LOG(FATAL) << "Null SSLSessionCacheManager in callback";
- return;
- }
- return manager->removeSession(ctx, session);
-}
-
-void SSLSessionCacheManager::removeSession(SSL_CTX* ctx,
- SSL_SESSION* session) {
- string sessionId((char*)session->session_id, session->session_id_length);
-
- // This hook is only called from SSL when the internal session cache needs to
- // flush sessions. Since we run with the internal cache disabled, this should
- // never be called
- VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId);
-
- localCache_->removeSession(sessionId);
-
- if (stats_) {
- stats_->recordSSLSessionRemove();
- }
-}
-
-SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl,
- unsigned char* sess_id,
- int id_len,
- int* copyflag) {
- SSLSessionCacheManager* manager = nullptr;
- SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
- manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
-
- if (manager == nullptr) {
- LOG(FATAL) << "Null SSLSessionCacheManager in callback";
- return nullptr;
- }
- return manager->getSession(ssl, sess_id, id_len, copyflag);
-}
-
-SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl,
- unsigned char* session_id,
- int id_len,
- int* copyflag) {
- VLOG(7) << "SSL get session callback";
- SSL_SESSION* session = nullptr;
- bool foreign = false;
- char const* missReason = nullptr;
-
- if (id_len < MIN_SESSION_ID_LENGTH) {
- // We didn't generate this session so it's going to be a miss.
- // This doesn't get logged or counted in the stats.
- return nullptr;
- }
- string sessionId((char*)session_id, id_len);
-
- AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
-
- assert(sslSocket != nullptr);
-
- // look it up in the local cache first
- session = localCache_->lookupSession(sessionId);
-#ifdef SSL_SESSION_CB_WOULD_BLOCK
- if (session == nullptr && externalCache_) {
- // external cache might have the session
- foreign = true;
- if (!SSL_want_sess_cache_lookup(ssl)) {
- missReason = "reason: No async cache support;";
- } else {
- PendingLookupMap::iterator pit = pendingLookups_.find(sessionId);
- if (pit == pendingLookups_.end()) {
- auto result = pendingLookups_.emplace(sessionId, PendingLookup());
- // initiate fetch
- VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" <<
- sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
- if (lookupCacheRecord(sessionId, sslSocket)) {
- // response is pending
- *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
- return nullptr;
- } else {
- missReason = "reason: failed to send lookup request;";
- pendingLookups_.erase(result.first);
- }
- } else {
- // A lookup was already initiated from this thread
- if (pit->second.request_in_progress) {
- // Someone else initiated the request, attach
- VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; "
- "fd=" << sslSocket->getFd() << " id=" <<
- SSLUtil::hexlify(sessionId);
- std::unique_ptr<DelayedDestruction::DestructorGuard> dg(
- new DelayedDestruction::DestructorGuard(sslSocket));
- pit->second.waiters.push_back(
- std::make_pair(sslSocket, std::move(dg)));
- *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
- return nullptr;
- }
- // request is complete
- session = pit->second.session; // nullptr if our friend didn't have it
- if (session != nullptr) {
- CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
- }
- }
- }
- }
-#endif
-
- bool hit = (session != nullptr);
- if (stats_) {
- stats_->recordSSLSession(false, hit, foreign);
- }
- if (hit) {
- sslSocket->setSessionIDResumed(true);
- }
-
- VLOG(4) << "Get SSL session [" <<
- ((hit) ? "Hit" : "Miss") << "]: " <<
- ((foreign) ? "external" : "local") << " cache; " <<
- ((missReason != nullptr) ? missReason : "") << "fd=" <<
- sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
-
- // We already bumped the refcount
- *copyflag = 0;
-
- return session;
-}
-
-bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId,
- SSL_SESSION* session) {
- std::string sessionString;
- uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr);
- sessionString.resize(sessionLen);
- uint8_t* cp = (uint8_t *)sessionString.data();
- i2d_SSL_SESSION(session, &cp);
- size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx());
- return externalCache_->setAsync(sessionId, sessionString,
- std::chrono::seconds(expiration));
-}
-
-bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId,
- AsyncSSLSocket* sslSocket) {
- auto cacheCtx = new SSLCacheProvider::CacheContext();
- cacheCtx->sessionId = sessionId;
- cacheCtx->session = nullptr;
- cacheCtx->sslSocket = sslSocket;
- cacheCtx->guard.reset(
- new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket));
- cacheCtx->manager = this;
- bool res = externalCache_->getAsync(sessionId, cacheCtx);
- if (!res) {
- delete cacheCtx;
- }
- return res;
-}
-
-void SSLSessionCacheManager::restartSSLAccept(
- const SSLCacheProvider::CacheContext* cacheCtx) {
- PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId);
- CHECK(pit != pendingLookups_.end());
- pit->second.request_in_progress = false;
- pit->second.session = cacheCtx->session;
- VLOG(7) << "Restart SSL accept";
- cacheCtx->sslSocket->restartSSLAccept();
- for (const auto& attachedLookup: pit->second.waiters) {
- // Wake up anyone else who was waiting for this session
- VLOG(4) << "Restart SSL accept (waiters) for fd=" <<
- attachedLookup.first->getFd();
- attachedLookup.first->restartSSLAccept();
- }
- pendingLookups_.erase(pit);
-}
-
-void SSLSessionCacheManager::onGetSuccess(
- SSLCacheProvider::CacheContext* cacheCtx,
- const std::string& value) {
- const uint8_t* cp = (uint8_t*)value.data();
- cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length());
- restartSSLAccept(cacheCtx);
-
- /* Insert in the LRU after restarting all clients. The stats logic
- * in getSession would treat this as a local hit otherwise.
- */
- localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_);
- delete cacheCtx;
-}
-
-void SSLSessionCacheManager::onGetFailure(
- SSLCacheProvider::CacheContext* cacheCtx) {
- restartSSLAccept(cacheCtx);
- delete cacheCtx;
-}
-
-} // namespace
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/experimental/wangle/ssl/SSLCacheProvider.h>
-#include <folly/experimental/wangle/ssl/SSLStats.h>
-
-#include <folly/EvictingCacheMap.h>
-#include <mutex>
-#include <folly/io/async/AsyncSSLSocket.h>
-
-namespace folly {
-
-class SSLStats;
-
-/**
- * Basic SSL session cache map: Maps session id -> session
- */
-typedef folly::EvictingCacheMap<std::string, SSL_SESSION*> SSLSessionCacheMap;
-
-/**
- * Holds an SSLSessionCacheMap and associated lock
- */
-class LocalSSLSessionCache: private boost::noncopyable {
- public:
- LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize);
-
- ~LocalSSLSessionCache() {
- std::lock_guard<std::mutex> g(lock);
- // EvictingCacheMap dtor doesn't free values
- sessionCache.clear();
- }
-
- SSLSessionCacheMap sessionCache;
- std::mutex lock;
- uint32_t removedSessions_{0};
-
- private:
-
- void pruneSessionCallback(const std::string& sessionId,
- SSL_SESSION* session);
-};
-
-/**
- * A sharded LRU for SSL sessions. The sharding is inteneded to reduce
- * contention for the LRU locks. Assuming uniform distribution, two workers
- * will contend for the same lock with probability 1 / n_buckets^2.
- */
-class ShardedLocalSSLSessionCache : private boost::noncopyable {
- public:
- ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize,
- uint32_t cacheCullSize) {
- CHECK(n_buckets > 0);
- maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets);
- cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets);
- if (maxCacheSize == 0) {
- maxCacheSize = 1;
- }
- if (cacheCullSize == 0) {
- cacheCullSize = 1;
- }
- for (uint32_t i = 0; i < n_buckets; i++) {
- caches_.push_back(
- std::unique_ptr<LocalSSLSessionCache>(
- new LocalSSLSessionCache(maxCacheSize, cacheCullSize)));
- }
- }
-
- SSL_SESSION* lookupSession(const std::string& sessionId) {
- size_t bucket = hash(sessionId);
- SSL_SESSION* session = nullptr;
- std::lock_guard<std::mutex> g(caches_[bucket]->lock);
-
- auto itr = caches_[bucket]->sessionCache.find(sessionId);
- if (itr != caches_[bucket]->sessionCache.end()) {
- session = itr->second;
- }
-
- if (session) {
- CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
- }
- return session;
- }
-
- void storeSession(const std::string& sessionId, SSL_SESSION* session,
- SSLStats* stats) {
- size_t bucket = hash(sessionId);
- SSL_SESSION* oldSession = nullptr;
- std::lock_guard<std::mutex> g(caches_[bucket]->lock);
-
- auto itr = caches_[bucket]->sessionCache.find(sessionId);
- if (itr != caches_[bucket]->sessionCache.end()) {
- oldSession = itr->second;
- }
-
- if (oldSession) {
- // LRUCacheMap doesn't free on overwrite, so 2x the work for us
- // This can happen in race conditions
- SSL_SESSION_free(oldSession);
- }
- caches_[bucket]->removedSessions_ = 0;
- caches_[bucket]->sessionCache.set(sessionId, session, true);
- if (stats) {
- stats->recordSSLSessionFree(caches_[bucket]->removedSessions_);
- }
- }
-
- void removeSession(const std::string& sessionId) {
- size_t bucket = hash(sessionId);
- std::lock_guard<std::mutex> g(caches_[bucket]->lock);
- caches_[bucket]->sessionCache.erase(sessionId);
- }
-
- private:
-
- /* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */
- size_t hash(const std::string& key) {
- CHECK(key.length() >= 2);
- return (key[0] << 8 | key[1]) % caches_.size();
- }
-
- std::vector< std::unique_ptr<LocalSSLSessionCache> > caches_;
-};
-
-/* A socket/DestructorGuard pair */
-typedef std::pair<AsyncSSLSocket *,
- std::unique_ptr<DelayedDestruction::DestructorGuard>>
- AttachedLookup;
-
-/**
- * PendingLookup structure
- *
- * Keeps track of clients waiting for an SSL session to be retrieved from
- * the external cache provider.
- */
-struct PendingLookup {
- bool request_in_progress;
- SSL_SESSION* session;
- std::list<AttachedLookup> waiters;
-
- PendingLookup() {
- request_in_progress = true;
- session = nullptr;
- }
-};
-
-/* Maps SSL session id to a PendingLookup structure */
-typedef std::map<std::string, PendingLookup> PendingLookupMap;
-
-/**
- * SSLSessionCacheManager handles all stateful session caching. There is an
- * instance of this object per SSL VIP per thread, with a 1:1 correlation with
- * SSL_CTX. The cache can work locally or in concert with an external cache
- * to share sessions across instances.
- *
- * There is a single in memory session cache shared by all VIPs. The cache is
- * split into N buckets (currently 16) with a separate lock per bucket. The
- * VIP ID is hashed and stored as part of the session to handle the
- * (very unlikely) case of session ID collision.
- *
- * When a new SSL session is created, it is added to the LRU cache and
- * sent to the external cache to be stored. The external cache
- * expiration is equal to the SSL session's expiration.
- *
- * When a resume request is received, SSLSessionCacheManager first looks in the
- * local LRU cache for the VIP. If there is a miss there, an asynchronous
- * request for this session is dispatched to the external cache. When the
- * external cache query returns, the LRU cache is updated if the session was
- * found, and the SSL_accept call is resumed.
- *
- * If additional resume requests for the same session ID arrive in the same
- * thread while the request is pending, the 2nd - Nth callers attach to the
- * original external cache requests and are resumed when it comes back. No
- * attempt is made to coalesce external cache requests for the same session
- * ID in different worker threads. Previous work did this, but the
- * complexity was deemed to outweigh the potential savings.
- *
- */
-class SSLSessionCacheManager : private boost::noncopyable {
- public:
- /**
- * Constructor. SSL session related callbacks will be set on the underlying
- * SSL_CTX. vipId is assumed to a unique string identifying the VIP and must
- * be the same on all servers that wish to share sessions via the same
- * external cache.
- */
- SSLSessionCacheManager(
- uint32_t maxCacheSize,
- uint32_t cacheCullSize,
- SSLContext* ctx,
- const folly::SocketAddress& sockaddr,
- const std::string& context,
- EventBase* eventBase,
- SSLStats* stats,
- const std::shared_ptr<SSLCacheProvider>& externalCache);
-
- virtual ~SSLSessionCacheManager();
-
- /**
- * Call this on shutdown to release the global instance of the
- * ShardedLocalSSLSessionCache.
- */
- static void shutdown();
-
- /**
- * Callback for ExternalCache to call when an async get succeeds
- * @param context The context that was passed to the async get request
- * @param value Serialized session
- */
- void onGetSuccess(SSLCacheProvider::CacheContext* context,
- const std::string& value);
-
- /**
- * Callback for ExternalCache to call when an async get fails, either
- * because the requested session is not in the external cache or because
- * of an error.
- * @param context The context that was passed to the async get request
- */
- void onGetFailure(SSLCacheProvider::CacheContext* context);
-
- private:
-
- SSLContext* ctx_;
- std::shared_ptr<ShardedLocalSSLSessionCache> localCache_;
- PendingLookupMap pendingLookups_;
- SSLStats* stats_{nullptr};
- std::shared_ptr<SSLCacheProvider> externalCache_;
-
- /**
- * Invoked by openssl when a new SSL session is created
- */
- int newSession(SSL* ssl, SSL_SESSION* session);
-
- /**
- * Invoked by openssl when an SSL session is ejected from its internal cache.
- * This can't be invoked in the current implementation because SSL's internal
- * caching is disabled.
- */
- void removeSession(SSL_CTX* ctx, SSL_SESSION* session);
-
- /**
- * Invoked by openssl when a client requests a stateful session resumption.
- * Triggers a lookup in our local cache and potentially an asynchronous
- * request to an external cache.
- */
- SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id,
- int id_len, int* copyflag);
-
- /**
- * Store a new session record in the external cache
- */
- bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session);
-
- /**
- * Lookup a session in the external cache for the specified SSL socket.
- */
- bool lookupCacheRecord(const std::string& sessionId,
- AsyncSSLSocket* sslSock);
-
- /**
- * Restart all clients waiting for the answer to an external cache query
- */
- void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx);
-
- /**
- * Get or create the LRU cache for the given VIP ID
- */
- static std::shared_ptr<ShardedLocalSSLSessionCache> getLocalCache(
- uint32_t maxCacheSize, uint32_t cacheCullSize);
-
- /**
- * static functions registered as callbacks to openssl via
- * SSL_CTX_sess_set_new/get/remove_cb
- */
- static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
- static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session);
- static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id,
- int id_len, int* copyflag);
-
- static int32_t sExDataIndex_;
- static std::shared_ptr<ShardedLocalSSLSessionCache> sCache_;
- static std::mutex sCacheLock_;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-namespace folly {
-
-class SSLStats {
- public:
- virtual ~SSLStats() noexcept {}
-
- // downstream
- virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0;
- virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0;
- virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign)
- noexcept = 0;
- virtual void recordSSLSessionRemove() noexcept = 0;
- virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0;
- virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0;
- virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0;
- virtual void recordClientRenegotiation() noexcept = 0;
-
- // upstream
- virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0;
- virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0;
- virtual void recordCryptoSSLExternalAttempt() noexcept = 0;
- virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0;
- virtual void recordCryptoSSLExternalApplicationException() noexcept = 0;
- virtual void recordCryptoSSLExternalSuccess() noexcept = 0;
- virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0;
- virtual void recordCryptoSSLLocalAttempt() noexcept = 0;
- virtual void recordCryptoSSLLocalSuccess() noexcept = 0;
-
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/ssl/SSLUtil.h>
-
-#include <folly/Memory.h>
-
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL
-#define OPENSSL_GE_101 1
-#include <openssl/asn1.h>
-#include <openssl/x509v3.h>
-#else
-#undef OPENSSL_GE_101
-#endif
-
-namespace folly {
-
-std::mutex SSLUtil::sIndexLock_;
-
-std::unique_ptr<std::string> SSLUtil::getCommonName(const X509* cert) {
- X509_NAME* subject = X509_get_subject_name((X509*)cert);
- if (!subject) {
- return nullptr;
- }
- char cn[ub_common_name + 1];
- int res = X509_NAME_get_text_by_NID(subject, NID_commonName,
- cn, ub_common_name);
- if (res <= 0) {
- return nullptr;
- } else {
- cn[ub_common_name] = '\0';
- return folly::make_unique<std::string>(cn);
- }
-}
-
-std::unique_ptr<std::list<std::string>> SSLUtil::getSubjectAltName(
- const X509* cert) {
-#ifdef OPENSSL_GE_101
- auto nameList = folly::make_unique<std::list<std::string>>();
- GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i(
- (X509*)cert, NID_subject_alt_name, nullptr, nullptr);
- if (names) {
- auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); });
- size_t count = sk_GENERAL_NAME_num(names);
- CHECK(count < std::numeric_limits<int>::max());
- for (int i = 0; i < (int)count; ++i) {
- GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i);
- if (generalName->type == GEN_DNS) {
- ASN1_STRING* s = generalName->d.dNSName;
- const char* name = (const char*)ASN1_STRING_data(s);
- // I can't find any docs on what a negative return value here
- // would mean, so I'm going to ignore it.
- auto len = ASN1_STRING_length(s);
- DCHECK(len >= 0);
- if (size_t(len) != strlen(name)) {
- // Null byte(s) in the name; return an error rather than depending on
- // the caller to safely handle this case.
- return nullptr;
- }
- nameList->emplace_back(name);
- }
- }
- }
- return nameList;
-#else
- return nullptr;
-#endif
-}
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/String.h>
-#include <mutex>
-#include <folly/io/async/AsyncSSLSocket.h>
-
-namespace folly {
-
-/**
- * SSL session establish/resume status
- *
- * changing these values will break logging pipelines
- */
-enum class SSLResumeEnum : uint8_t {
- HANDSHAKE = 0,
- RESUME_SESSION_ID = 1,
- RESUME_TICKET = 3,
- NA = 2
-};
-
-enum class SSLErrorEnum {
- NO_ERROR,
- TIMEOUT,
- DROPPED
-};
-
-class SSLUtil {
- private:
- static std::mutex sIndexLock_;
-
- public:
- /**
- * Ensures only one caller will allocate an ex_data index for a given static
- * or global.
- */
- static void getSSLCtxExIndex(int* pindex) {
- std::lock_guard<std::mutex> g(sIndexLock_);
- if (*pindex < 0) {
- *pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
- }
- }
-
- static void getRSAExIndex(int* pindex) {
- std::lock_guard<std::mutex> g(sIndexLock_);
- if (*pindex < 0) {
- *pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
- }
- }
-
- static inline std::string hexlify(const std::string& binary) {
- std::string hex;
- folly::hexlify<std::string, std::string>(binary, hex);
-
- return hex;
- }
-
- static inline const std::string& hexlify(const std::string& binary,
- std::string& hex) {
- folly::hexlify<std::string, std::string>(binary, hex);
-
- return hex;
- }
-
- /**
- * Return the SSL resume type for the given socket.
- */
- static inline SSLResumeEnum getResumeState(
- AsyncSSLSocket* sslSocket) {
- return sslSocket->getSSLSessionReused() ?
- (sslSocket->sessionIDResumed() ?
- SSLResumeEnum::RESUME_SESSION_ID :
- SSLResumeEnum::RESUME_TICKET) :
- SSLResumeEnum::HANDSHAKE;
- }
-
- /**
- * Get the Common Name from an X.509 certificate
- * @param cert certificate to inspect
- * @return common name, or null if an error occurs
- */
- static std::unique_ptr<std::string> getCommonName(const X509* cert);
-
- /**
- * Get the Subject Alternative Name value(s) from an X.509 certificate
- * @param cert certificate to inspect
- * @return set of zero or more alternative names, or null if
- * an error occurs
- */
- static std::unique_ptr<std::list<std::string>> getSubjectAltName(
- const X509* cert);
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/experimental/wangle/ssl/TLSTicketKeyManager.h>
-
-#include <folly/experimental/wangle/ssl/SSLStats.h>
-#include <folly/experimental/wangle/ssl/SSLUtil.h>
-
-#include <folly/String.h>
-#include <openssl/aes.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-#include <folly/io/async/AsyncTimeout.h>
-
-#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
-using std::string;
-
-namespace {
-
-const int kTLSTicketKeyNameLen = 4;
-const int kTLSTicketKeySaltLen = 12;
-
-}
-
-namespace folly {
-
-
-// TLSTicketKeyManager Implementation
-int32_t TLSTicketKeyManager::sExDataIndex_ = -1;
-
-TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats)
- : ctx_(ctx),
- randState_(0),
- stats_(stats) {
- SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
- SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this);
-}
-
-TLSTicketKeyManager::~TLSTicketKeyManager() {
-}
-
-int
-TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName,
- unsigned char* iv,
- EVP_CIPHER_CTX* cipherCtx,
- HMAC_CTX* hmacCtx, int encrypt) {
- TLSTicketKeyManager* manager = nullptr;
- SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
- manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
-
- if (manager == nullptr) {
- LOG(FATAL) << "Null TLSTicketKeyManager in callback" ;
- return -1;
- }
- return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt);
-}
-
-int
-TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName,
- unsigned char* iv,
- EVP_CIPHER_CTX* cipherCtx,
- HMAC_CTX* hmacCtx, int encrypt) {
- uint8_t salt[kTLSTicketKeySaltLen];
- uint8_t* saltptr = nullptr;
- uint8_t output[SHA256_DIGEST_LENGTH];
- uint8_t* hmacKey = nullptr;
- uint8_t* aesKey = nullptr;
- TLSTicketKeySource* key = nullptr;
- int result = 0;
-
- if (encrypt) {
- key = findEncryptionKey();
- if (key == nullptr) {
- // no keys available to encrypt
- VLOG(2) << "No TLS ticket key found";
- return -1;
- }
- VLOG(4) << "Encrypting new ticket with key name=" <<
- SSLUtil::hexlify(key->keyName_);
-
- // Get a random salt and write out key name
- RAND_pseudo_bytes(salt, (int)sizeof(salt));
- memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen);
- memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen);
-
- // Create the unique keys by hashing with the salt
- makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output);
- // This relies on the fact that SHA256 has 32 bytes of output
- // and that AES-128 keys are 16 bytes
- hmacKey = output;
- aesKey = output + SHA256_DIGEST_LENGTH / 2;
-
- // Initialize iv and cipher/mac CTX
- RAND_pseudo_bytes(iv, AES_BLOCK_SIZE);
- HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
- EVP_sha256(), nullptr);
- EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
-
- result = 1;
- } else {
- key = findDecryptionKey(keyName);
- if (key == nullptr) {
- // no ticket found for decryption - will issue a new ticket
- if (VLOG_IS_ON(4)) {
- string skeyName((char *)keyName, kTLSTicketKeyNameLen);
- VLOG(4) << "Can't find ticket key with name=" <<
- SSLUtil::hexlify(skeyName)<< ", will generate new ticket";
- }
-
- result = 0;
- } else {
- VLOG(4) << "Decrypting ticket with key name=" <<
- SSLUtil::hexlify(key->keyName_);
-
- // Reconstruct the unique key via the salt
- saltptr = keyName + kTLSTicketKeyNameLen;
- makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output);
- hmacKey = output;
- aesKey = output + SHA256_DIGEST_LENGTH / 2;
-
- // Initialize cipher/mac CTX
- HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
- EVP_sha256(), nullptr);
- EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
-
- result = 1;
- }
- }
- // result records whether a ticket key was found to decrypt this ticket,
- // not wether the session was re-used.
- if (stats_) {
- stats_->recordTLSTicket(encrypt, result);
- }
-
- return result;
-}
-
-bool
-TLSTicketKeyManager::setTLSTicketKeySeeds(
- const std::vector<std::string>& oldSeeds,
- const std::vector<std::string>& currentSeeds,
- const std::vector<std::string>& newSeeds) {
-
- bool result = true;
-
- activeKeys_.clear();
- ticketKeys_.clear();
- ticketSeeds_.clear();
- const std::vector<string> *seedList = &oldSeeds;
- for (uint32_t i = 0; i < 3; i++) {
- TLSTicketSeedType type = (TLSTicketSeedType)i;
- if (type == SEED_CURRENT) {
- seedList = ¤tSeeds;
- } else if (type == SEED_NEW) {
- seedList = &newSeeds;
- }
-
- for (const auto& seedInput: *seedList) {
- TLSTicketSeed* seed = insertSeed(seedInput, type);
- if (seed == nullptr) {
- result = false;
- continue;
- }
- insertNewKey(seed, 1, nullptr);
- }
- }
- if (!result) {
- VLOG(2) << "One or more seeds failed to decode";
- }
-
- if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) {
- LOG(WARNING) << "No keys configured, falling back to default";
- SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr);
- return false;
- }
- SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(),
- TLSTicketKeyManager::callback);
-
- return true;
-}
-
-string
-TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n,
- unsigned char* nameBuf) {
- SHA256_CTX ctx;
-
- SHA256_Init(&ctx);
- SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_));
- SHA256_Update(&ctx, &n, sizeof(n));
- SHA256_Final(nameBuf, &ctx);
- return string((char *)nameBuf, kTLSTicketKeyNameLen);
-}
-
-TLSTicketKeyManager::TLSTicketKeySource*
-TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
- TLSTicketKeySource* prevKey) {
- unsigned char nameBuf[SHA256_DIGEST_LENGTH];
- std::unique_ptr<TLSTicketKeySource> newKey(new TLSTicketKeySource);
-
- // This function supports hash chaining but it is not currently used.
-
- if (prevKey != nullptr) {
- hashNth(prevKey->keySource_, sizeof(prevKey->keySource_),
- newKey->keySource_, 1);
- } else {
- // can't go backwards or the current is missing, start from the beginning
- hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(),
- newKey->keySource_, hashCount);
- }
-
- newKey->hashCount_ = hashCount;
- newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf);
- newKey->type_ = seed->type_;
- auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_,
- std::move(newKey)));
-
- auto key = it.first->second.get();
- if (key->type_ == SEED_CURRENT) {
- activeKeys_.push_back(key);
- }
- VLOG(4) << "Adding key for " << hashCount << " type=" <<
- (uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_);
-
- return key;
-}
-
-void
-TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len,
- unsigned char* output, uint32_t n) {
- assert(n > 0);
- for (uint32_t i = 0; i < n; i++) {
- SHA256(input, input_len, output);
- input = output;
- input_len = SHA256_DIGEST_LENGTH;
- }
-}
-
-TLSTicketKeyManager::TLSTicketSeed *
-TLSTicketKeyManager::insertSeed(const string& seedInput,
- TLSTicketSeedType type) {
- TLSTicketSeed* seed = nullptr;
- string seedOutput;
-
- if (!folly::unhexlify<string, string>(seedInput, seedOutput)) {
- LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type <<
- " seed=" << seedInput;
- return seed;
- }
-
- seed = new TLSTicketSeed();
- seed->seed_ = seedOutput;
- seed->type_ = type;
- SHA256((unsigned char *)seedOutput.data(), seedOutput.length(),
- seed->seedName_);
- ticketSeeds_.push_back(std::unique_ptr<TLSTicketSeed>(seed));
-
- return seed;
-}
-
-TLSTicketKeyManager::TLSTicketKeySource *
-TLSTicketKeyManager::findEncryptionKey() {
- TLSTicketKeySource* result = nullptr;
- // call to rand here is a bit hokey since it's not cryptographically
- // random, and is predictably seeded with 0. However, activeKeys_
- // is probably not going to have very many keys in it, and most
- // likely only 1.
- size_t numKeys = activeKeys_.size();
- if (numKeys > 0) {
- result = activeKeys_[rand_r(&randState_) % numKeys];
- }
- return result;
-}
-
-TLSTicketKeyManager::TLSTicketKeySource *
-TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) {
- string name((char *)keyName, kTLSTicketKeyNameLen);
- TLSTicketKeySource* key = nullptr;
- TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name);
- if (mapit != ticketKeys_.end()) {
- key = mapit->second.get();
- }
- return key;
-}
-
-void
-TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey,
- size_t keyLen,
- unsigned char* salt,
- unsigned char* output) {
- SHA256_CTX hash_ctx;
-
- SHA256_Init(&hash_ctx);
- SHA256_Update(&hash_ctx, parentKey, keyLen);
- SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen);
- SHA256_Final(output, &hash_ctx);
-}
-
-} // namespace
-#endif
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-#include <folly/io/async/SSLContext.h>
-#include <folly/io/async/EventBase.h>
-
-namespace folly {
-
-#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
-class TLSTicketKeyManager {};
-#else
-class SSLStats;
-/**
- * The TLSTicketKeyManager handles TLS ticket key encryption and decryption in
- * a way that facilitates sharing the ticket keys across a range of servers.
- * Hash chaining is employed to achieve frequent key rotation with minimal
- * configuration change. The scheme is as follows:
- *
- * The manager is supplied with three lists of seeds (old, current and new).
- * The config should be updated with new seeds periodically (e.g., daily).
- * 3 config changes are recommended to achieve the smoothest seed rotation
- * eg:
- * 1. Introduce new seed in the push prior to rotation
- * 2. Rotation push
- * 3. Remove old seeds in the push following rotation
- *
- * Multiple seeds are supported but only a single seed is required.
- *
- * Generating encryption keys from the seed works as follows. For a given
- * seed, hash forward N times where N is currently the constant 1.
- * This is the base key. The name of the base key is the first 4
- * bytes of hash(hash(seed), N). This is copied into the first 4 bytes of the
- * TLS ticket key name field.
- *
- * For each new ticket encryption, the manager generates a random 12 byte salt.
- * Hash the salt and the base key together to form the encryption key for
- * that ticket. The salt is included in the ticket's 'key name' field so it
- * can be used to derive the decryption key. The salt is copied into the second
- * 8 bytes of the TLS ticket key name field.
- *
- * A key is valid for decryption for the lifetime of the instance.
- * Sessions will be valid for less time than that, which results in an extra
- * symmetric decryption to discover the session is expired.
- *
- * A TLSTicketKeyManager should be used in only one thread, and should have
- * a 1:1 relationship with the SSLContext provided.
- *
- */
-class TLSTicketKeyManager : private boost::noncopyable {
- public:
-
- explicit TLSTicketKeyManager(folly::SSLContext* ctx,
- SSLStats* stats);
-
- virtual ~TLSTicketKeyManager();
-
- /**
- * SSL callback to set up encryption/decryption context for a TLS Ticket Key.
- *
- * This will be supplied to the SSL library via
- * SSL_CTX_set_tlsext_ticket_key_cb.
- */
- static int callback(SSL* ssl, unsigned char* keyName,
- unsigned char* iv,
- EVP_CIPHER_CTX* cipherCtx,
- HMAC_CTX* hmacCtx, int encrypt);
-
- /**
- * Initialize the manager with three sets of seeds. There must be at least
- * one current seed, or the manager will revert to the default SSL behavior.
- *
- * @param oldSeeds Seeds previously used which can still decrypt.
- * @param currentSeeds Seeds to use for new ticket encryptions.
- * @param newSeeds Seeds which will be used soon, can be used to decrypt
- * in case some servers in the cluster have already rotated.
- */
- bool setTLSTicketKeySeeds(const std::vector<std::string>& oldSeeds,
- const std::vector<std::string>& currentSeeds,
- const std::vector<std::string>& newSeeds);
-
- private:
- enum TLSTicketSeedType {
- SEED_OLD = 0,
- SEED_CURRENT,
- SEED_NEW
- };
-
- /* The seeds supplied by the configuration */
- struct TLSTicketSeed {
- std::string seed_;
- TLSTicketSeedType type_;
- unsigned char seedName_[SHA256_DIGEST_LENGTH];
- };
-
- struct TLSTicketKeySource {
- int32_t hashCount_;
- std::string keyName_;
- TLSTicketSeedType type_;
- unsigned char keySource_[SHA256_DIGEST_LENGTH];
- };
-
- /**
- * Method to setup encryption/decryption context for a TLS Ticket Key
- *
- * OpenSSL documentation is thin on the return value semantics.
- *
- * For encrypt=1, return < 0 on error, >= 0 for successfully initialized
- * For encrypt=0, return < 0 on error, 0 on key not found
- * 1 on key found, 2 renew_ticket
- *
- * renew_ticket means a new ticket will be issued. We could return this value
- * when receiving a ticket encrypted with a key derived from an OLD seed.
- * However, session_timeout seconds after deploying with a seed
- * rotated from CURRENT -> OLD, there will be no valid tickets outstanding
- * encrypted with the old key. This grace period means no unnecessary
- * handshakes will be performed. If the seed is believed compromised, it
- * should NOT be configured as an OLD seed.
- */
- int processTicket(SSL* ssl, unsigned char* keyName,
- unsigned char* iv,
- EVP_CIPHER_CTX* cipherCtx,
- HMAC_CTX* hmacCtx, int encrypt);
-
- // Creates the name for the nth key generated from seed
- std::string makeKeyName(TLSTicketSeed* seed, uint32_t n,
- unsigned char* nameBuf);
-
- /**
- * Creates the key hashCount hashes from the given seed and inserts it in
- * ticketKeys. A naked pointer to the key is returned for additional
- * processing if needed.
- */
- TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
- TLSTicketKeySource* prevKeySource);
-
- /**
- * hashes input N times placing result in output, which must be at least
- * SHA256_DIGEST_LENGTH long.
- */
- void hashNth(const unsigned char* input, size_t input_len,
- unsigned char* output, uint32_t n);
-
- /**
- * Adds the given seed to the manager
- */
- TLSTicketSeed* insertSeed(const std::string& seedInput,
- TLSTicketSeedType type);
-
- /**
- * Locate a key for encrypting a new ticket
- */
- TLSTicketKeySource* findEncryptionKey();
-
- /**
- * Locate a key for decrypting a ticket with the given keyName
- */
- TLSTicketKeySource* findDecryptionKey(unsigned char* keyName);
-
- /**
- * Derive a unique key from the parent key and the salt via hashing
- */
- void makeUniqueKeys(unsigned char* parentKey, size_t keyLen,
- unsigned char* salt, unsigned char* output);
-
- /**
- * For standalone decryption utility
- */
- friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager,
- const std::string& testTicket,
- SSL_SESSION **psess);
-
- typedef std::vector<std::unique_ptr<TLSTicketSeed>> TLSTicketSeedList;
- typedef std::map<std::string, std::unique_ptr<TLSTicketKeySource> >
- TLSTicketKeyMap;
- typedef std::vector<TLSTicketKeySource *> TLSActiveKeyList;
-
- TLSTicketSeedList ticketSeeds_;
- // All key sources that can be used for decryption
- TLSTicketKeyMap ticketKeys_;
- // Key sources that can be used for encryption
- TLSActiveKeyList activeKeys_;
-
- folly::SSLContext* ctx_;
- uint32_t randState_;
- SSLStats* stats_{nullptr};
-
- static int32_t sExDataIndex_;
-};
-#endif
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#pragma once
-
-namespace folly {
-
-struct TLSTicketKeySeeds {
- std::vector<std::string> oldSeeds;
- std::vector<std::string> currentSeeds;
- std::vector<std::string> newSeeds;
-};
-
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/Portability.h>
-#include <folly/io/async/EventBase.h>
-#include <gflags/gflags.h>
-#include <iostream>
-#include <thread>
-#include <folly/io/async/AsyncSSLSocket.h>
-#include <folly/io/async/AsyncSocket.h>
-#include <vector>
-
-using namespace std;
-using namespace folly;
-
-DEFINE_int32(clients, 1, "Number of simulated SSL clients");
-DEFINE_int32(threads, 1, "Number of threads to spread clients across");
-DEFINE_int32(requests, 2, "Total number of requests per client");
-DEFINE_int32(port, 9423, "Server port");
-DEFINE_bool(sticky, false, "A given client sends all reqs to one "
- "(random) server");
-DEFINE_bool(global, false, "All clients in a thread use the same SSL session");
-DEFINE_bool(handshakes, false, "Force 100% handshakes");
-
-string f_servers[10];
-int f_num_servers = 0;
-int tnum = 0;
-
-class ClientRunner {
- public:
-
- ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {}
- void run();
-
- int reqs;
- int hits;
- int miss;
- int num;
-};
-
-class SSLCacheClient : public AsyncSocket::ConnectCallback,
- public AsyncSSLSocket::HandshakeCB
-{
-private:
- EventBase* eventBase_;
- int currReq_;
- int serverIdx_;
- AsyncSocket* socket_;
- AsyncSSLSocket* sslSocket_;
- SSL_SESSION* session_;
- SSL_SESSION **pSess_;
- std::shared_ptr<SSLContext> ctx_;
- ClientRunner* cr_;
-
-public:
- SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr);
- ~SSLCacheClient() {
- if (session_ && !FLAGS_global)
- SSL_SESSION_free(session_);
- if (socket_ != nullptr) {
- if (sslSocket_ != nullptr) {
- sslSocket_->destroy();
- sslSocket_ = nullptr;
- }
- socket_->destroy();
- socket_ = nullptr;
- }
- };
-
- void start();
-
- virtual void connectSuccess() noexcept;
-
- virtual void connectErr(const AsyncSocketException& ex)
- noexcept ;
-
- virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept;
-
- virtual void handshakeErr(
- AsyncSSLSocket* sock,
- const AsyncSocketException& ex) noexcept;
-
-};
-
-int
-main(int argc, char* argv[])
-{
- gflags::SetUsageMessage(std::string("\n\n"
-"usage: sslcachetest [options] -c <clients> -t <threads> servers\n"
-));
- gflags::ParseCommandLineFlags(&argc, &argv, true);
- int reqs = 0;
- int hits = 0;
- int miss = 0;
- struct timeval start;
- struct timeval end;
- struct timeval result;
-
- srand((unsigned int)time(nullptr));
-
- for (int i = 1; i < argc; i++) {
- f_servers[f_num_servers++] = argv[i];
- }
- if (f_num_servers == 0) {
- cout << "require at least one server\n";
- return 1;
- }
-
- gettimeofday(&start, nullptr);
- if (FLAGS_threads == 1) {
- ClientRunner r;
- r.run();
- gettimeofday(&end, nullptr);
- reqs = r.reqs;
- hits = r.hits;
- miss = r.miss;
- }
- else {
- std::vector<ClientRunner> clients;
- std::vector<std::thread> threads;
- for (int t = 0; t < FLAGS_threads; t++) {
- threads.emplace_back([&] {
- clients[t].run();
- });
- }
- for (auto& thr: threads) {
- thr.join();
- }
- gettimeofday(&end, nullptr);
-
- for (const auto& client: clients) {
- reqs += client.reqs;
- hits += client.hits;
- miss += client.miss;
- }
- }
-
- timersub(&end, &start, &result);
-
- cout << "Requests: " << reqs << endl;
- cout << "Handshakes: " << miss << endl;
- cout << "Resumes: " << hits << endl;
- cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 <<
- endl;
-
- cout << "ops/sec: " << (reqs * 1.0) /
- ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl;
-
- return 0;
-}
-
-void
-ClientRunner::run()
-{
- EventBase eb;
- std::list<SSLCacheClient *> clients;
- SSL_SESSION* session = nullptr;
-
- for (int i = 0; i < FLAGS_clients; i++) {
- SSLCacheClient* c = new SSLCacheClient(&eb, &session, this);
- c->start();
- clients.push_back(c);
- }
-
- eb.loop();
-
- for (auto it = clients.begin(); it != clients.end(); it++) {
- delete* it;
- }
-
- reqs += hits + miss;
-}
-
-SSLCacheClient::SSLCacheClient(EventBase* eb,
- SSL_SESSION **pSess,
- ClientRunner* cr)
- : eventBase_(eb),
- currReq_(0),
- serverIdx_(0),
- socket_(nullptr),
- sslSocket_(nullptr),
- session_(nullptr),
- pSess_(pSess),
- cr_(cr)
-{
- ctx_.reset(new SSLContext());
- ctx_->setOptions(SSL_OP_NO_TICKET);
-}
-
-void
-SSLCacheClient::start()
-{
- if (currReq_ >= FLAGS_requests) {
- cout << "+";
- return;
- }
-
- if (currReq_ == 0 || !FLAGS_sticky) {
- serverIdx_ = rand() % f_num_servers;
- }
- if (socket_ != nullptr) {
- if (sslSocket_ != nullptr) {
- sslSocket_->destroy();
- sslSocket_ = nullptr;
- }
- socket_->destroy();
- socket_ = nullptr;
- }
- socket_ = new AsyncSocket(eventBase_);
- socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port);
-}
-
-void
-SSLCacheClient::connectSuccess() noexcept
-{
- sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(),
- false);
-
- if (!FLAGS_handshakes) {
- if (session_ != nullptr)
- sslSocket_->setSSLSession(session_);
- else if (FLAGS_global && pSess_ != nullptr)
- sslSocket_->setSSLSession(*pSess_);
- }
- sslSocket_->sslConn(this);
-}
-
-void
-SSLCacheClient::connectErr(const AsyncSocketException& ex)
- noexcept
-{
- cout << "connectError: " << ex.what() << endl;
-}
-
-void
-SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept
-{
- if (sslSocket_->getSSLSessionReused()) {
- cr_->hits++;
- } else {
- cr_->miss++;
- if (session_ != nullptr) {
- SSL_SESSION_free(session_);
- }
- session_ = sslSocket_->getSSLSession();
- if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) {
- *pSess_ = session_;
- }
- }
- if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) {
- cout << ".";
- cout.flush();
- }
- sslSocket_->closeNow();
- currReq_++;
- this->start();
-}
-
-void
-SSLCacheClient::handshakeErr(
- AsyncSSLSocket* sock,
- const AsyncSocketException& ex)
- noexcept
-{
- cout << "handshakeError: " << ex.what() << endl;
-}
+++ /dev/null
-/*
- * Copyright (c) 2014, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/SSLContext.h>
-#include <glog/logging.h>
-#include <gtest/gtest.h>
-#include <folly/experimental/wangle/ssl/SSLContextManager.h>
-#include <folly/experimental/wangle/acceptor/DomainNameMisc.h>
-
-using std::shared_ptr;
-
-namespace folly {
-
-TEST(SSLContextManagerTest, Test1)
-{
- EventBase eventBase;
- SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_",
- true, nullptr);
- auto www_facebook_com_ctx = std::make_shared<SSLContext>();
- auto start_facebook_com_ctx = std::make_shared<SSLContext>();
- auto start_abc_facebook_com_ctx = std::make_shared<SSLContext>();
-
- sslCtxMgr.insertSSLCtxByDomainName(
- "www.facebook.com",
- strlen("www.facebook.com"),
- www_facebook_com_ctx);
- sslCtxMgr.insertSSLCtxByDomainName(
- "www.facebook.com",
- strlen("www.facebook.com"),
- www_facebook_com_ctx);
- try {
- sslCtxMgr.insertSSLCtxByDomainName(
- "www.facebook.com",
- strlen("www.facebook.com"),
- std::make_shared<SSLContext>());
- } catch (const std::exception& ex) {
- }
- sslCtxMgr.insertSSLCtxByDomainName(
- "*.facebook.com",
- strlen("*.facebook.com"),
- start_facebook_com_ctx);
- sslCtxMgr.insertSSLCtxByDomainName(
- "*.abc.facebook.com",
- strlen("*.abc.facebook.com"),
- start_abc_facebook_com_ctx);
- try {
- sslCtxMgr.insertSSLCtxByDomainName(
- "*.abc.facebook.com",
- strlen("*.abc.facebook.com"),
- std::make_shared<SSLContext>());
- FAIL();
- } catch (const std::exception& ex) {
- }
-
- shared_ptr<SSLContext> retCtx;
- retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com"));
- EXPECT_EQ(retCtx, www_facebook_com_ctx);
- retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com"));
- EXPECT_EQ(retCtx, www_facebook_com_ctx);
- EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com")));
-
- retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com"));
- EXPECT_EQ(retCtx, start_facebook_com_ctx);
- retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com"));
- EXPECT_EQ(retCtx, start_facebook_com_ctx);
-
- retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com"));
- EXPECT_EQ(retCtx, start_abc_facebook_com_ctx);
-
- // ensure "facebook.com" does not match "*.facebook.com"
- EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com")));
- // ensure "Xfacebook.com" does not match "*.facebook.com"
- EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com")));
- // ensure wildcard name only matches one domain up
- EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com")));
-
- eventBase.loop(); // Clean up events before SSLContextManager is destructed
-}
-
-}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/acceptor/Acceptor.h>
+
+#include <folly/wangle/acceptor/ManagedConnection.h>
+#include <folly/wangle/ssl/SSLContextManager.h>
+
+#include <boost/cast.hpp>
+#include <fcntl.h>
+#include <folly/ScopeGuard.h>
+#include <folly/wangle/acceptor/ManagedConnection.h>
+#include <folly/io/async/EventBase.h>
+#include <fstream>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/EventBase.h>
+#include <unistd.h>
+
+using folly::wangle::ConnectionManager;
+using folly::wangle::ManagedConnection;
+using std::chrono::microseconds;
+using std::chrono::milliseconds;
+using std::filebuf;
+using std::ifstream;
+using std::ios;
+using std::shared_ptr;
+using std::string;
+
+namespace folly {
+
+#ifndef NO_LIB_GFLAGS
+DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
+ "closing idle conns");
+#else
+const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
+#endif
+
+static const std::string empty_string;
+std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
+
+/**
+ * Lightweight wrapper class to keep track of a newly
+ * accepted connection during SSL handshaking.
+ */
+class AcceptorHandshakeHelper :
+ public AsyncSSLSocket::HandshakeCB,
+ public ManagedConnection {
+ public:
+ AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
+ Acceptor* acceptor,
+ const SocketAddress& clientAddr,
+ std::chrono::steady_clock::time_point acceptTime)
+ : socket_(std::move(socket)), acceptor_(acceptor),
+ acceptTime_(acceptTime), clientAddr_(clientAddr) {
+ acceptor_->downstreamConnectionManager_->addConnection(this, true);
+ if(acceptor_->parseClientHello_) {
+ socket_->enableClientHelloParsing();
+ }
+ socket_->sslAccept(this);
+ }
+
+ virtual void timeoutExpired() noexcept {
+ VLOG(4) << "SSL handshake timeout expired";
+ sslError_ = SSLErrorEnum::TIMEOUT;
+ dropConnection();
+ }
+ virtual void describe(std::ostream& os) const {
+ os << "pending handshake on " << clientAddr_;
+ }
+ virtual bool isBusy() const {
+ return true;
+ }
+ virtual void notifyPendingShutdown() {}
+ virtual void closeWhenIdle() {}
+
+ virtual void dropConnection() {
+ VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
+ socket_->closeNow();
+ }
+ virtual void dumpConnectionState(uint8_t loglevel) {
+ }
+
+ private:
+ // AsyncSSLSocket::HandshakeCallback API
+ virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
+
+ const unsigned char* nextProto = nullptr;
+ unsigned nextProtoLength = 0;
+ sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
+ if (VLOG_IS_ON(3)) {
+ if (nextProto) {
+ VLOG(3) << "Client selected next protocol " <<
+ string((const char*)nextProto, nextProtoLength);
+ } else {
+ VLOG(3) << "Client did not select a next protocol";
+ }
+ }
+
+ // fill in SSL-related fields from TransportInfo
+ // the other fields like RTT are filled in the Acceptor
+ TransportInfo tinfo;
+ tinfo.ssl = true;
+ tinfo.acceptTime = acceptTime_;
+ tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
+ tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
+ tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
+ tinfo.sslServerName = sock->getSSLServerName();
+ tinfo.sslCipher = sock->getNegotiatedCipherName();
+ tinfo.sslVersion = sock->getSSLVersion();
+ tinfo.sslCertSize = sock->getSSLCertSize();
+ tinfo.sslResume = SSLUtil::getResumeState(sock);
+ sock->getSSLClientCiphers(tinfo.sslClientCiphers);
+ sock->getSSLServerCiphers(tinfo.sslServerCiphers);
+ tinfo.sslClientComprMethods = sock->getSSLClientComprMethods();
+ tinfo.sslClientExts = sock->getSSLClientExts();
+ tinfo.sslNextProtocol.assign(
+ reinterpret_cast<const char*>(nextProto),
+ nextProtoLength);
+
+ acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
+ acceptor_->downstreamConnectionManager_->removeConnection(this);
+ acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
+ nextProto ? string((const char*)nextProto, nextProtoLength) :
+ empty_string, tinfo);
+ delete this;
+ }
+
+ virtual void handshakeErr(AsyncSSLSocket* sock,
+ const AsyncSocketException& ex) noexcept {
+ auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
+ VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
+ " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
+ sock->getRawBytesWritten() << " bytes sent: " <<
+ ex.what();
+ acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
+ acceptor_->sslConnectionError();
+ delete this;
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ Acceptor* acceptor_;
+ std::chrono::steady_clock::time_point acceptTime_;
+ SocketAddress clientAddr_;
+ SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
+};
+
+Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
+ accConfig_(accConfig),
+ socketOptions_(accConfig.getSocketOptions()) {
+}
+
+void
+Acceptor::init(AsyncServerSocket* serverSocket,
+ EventBase* eventBase) {
+ CHECK(nullptr == this->base_);
+
+ if (accConfig_.isSSL()) {
+ if (!sslCtxManager_) {
+ sslCtxManager_ = folly::make_unique<SSLContextManager>(
+ eventBase,
+ "vip_" + getName(),
+ accConfig_.strictSSL, nullptr);
+ }
+ for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
+ sslCtxManager_->addSSLContextConfig(
+ sslCtxConfig,
+ accConfig_.sslCacheOptions,
+ &accConfig_.initialTicketSeeds,
+ accConfig_.bindAddress,
+ cacheProvider_);
+ parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
+ }
+
+ CHECK(sslCtxManager_->getDefaultSSLCtx());
+ }
+
+ base_ = eventBase;
+ state_ = State::kRunning;
+ downstreamConnectionManager_ = ConnectionManager::makeUnique(
+ eventBase, accConfig_.connectionIdleTimeout, this);
+
+ if (serverSocket) {
+ serverSocket->addAcceptCallback(this, eventBase);
+ // SO_KEEPALIVE is the only setting that is inherited by accepted
+ // connections so only apply this setting
+ for (const auto& option: socketOptions_) {
+ if (option.first.level == SOL_SOCKET &&
+ option.first.optname == SO_KEEPALIVE && option.second == 1) {
+ serverSocket->setKeepAliveEnabled(true);
+ break;
+ }
+ }
+ }
+}
+
+Acceptor::~Acceptor(void) {
+}
+
+void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
+ sslCtxManager_->addSSLContextConfig(sslCtxConfig,
+ accConfig_.sslCacheOptions,
+ &accConfig_.initialTicketSeeds,
+ accConfig_.bindAddress,
+ cacheProvider_);
+}
+
+void
+Acceptor::drainAllConnections() {
+ if (downstreamConnectionManager_) {
+ downstreamConnectionManager_->initiateGracefulShutdown(
+ std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
+ }
+}
+
+void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
+ IConnectionCounter* counter) {
+ loadShedConfig_ = from;
+ connectionCounter_ = counter;
+}
+
+bool Acceptor::canAccept(const SocketAddress& address) {
+ if (!connectionCounter_) {
+ return true;
+ }
+
+ uint64_t maxConnections = connectionCounter_->getMaxConnections();
+ if (maxConnections == 0) {
+ return true;
+ }
+
+ uint64_t currentConnections = connectionCounter_->getNumConnections();
+ if (currentConnections < maxConnections) {
+ return true;
+ }
+
+ if (loadShedConfig_.isWhitelisted(address)) {
+ return true;
+ }
+
+ // Take care of comparing connection count against max connections across
+ // all acceptors. Expensive since a lock must be taken to get the counter.
+ auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
+ if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
+ return true;
+ }
+
+ VLOG(4) << address.describe() << " not whitelisted";
+ return false;
+}
+
+void
+Acceptor::connectionAccepted(
+ int fd, const SocketAddress& clientAddr) noexcept {
+ if (!canAccept(clientAddr)) {
+ close(fd);
+ return;
+ }
+ auto acceptTime = std::chrono::steady_clock::now();
+ for (const auto& opt: socketOptions_) {
+ opt.first.apply(fd, opt.second);
+ }
+
+ onDoneAcceptingConnection(fd, clientAddr, acceptTime);
+}
+
+void Acceptor::onDoneAcceptingConnection(
+ int fd,
+ const SocketAddress& clientAddr,
+ std::chrono::steady_clock::time_point acceptTime) noexcept {
+ processEstablishedConnection(fd, clientAddr, acceptTime);
+}
+
+void
+Acceptor::processEstablishedConnection(
+ int fd,
+ const SocketAddress& clientAddr,
+ std::chrono::steady_clock::time_point acceptTime) noexcept {
+ if (accConfig_.isSSL()) {
+ CHECK(sslCtxManager_);
+ AsyncSSLSocket::UniquePtr sslSock(
+ makeNewAsyncSSLSocket(
+ sslCtxManager_->getDefaultSSLCtx(), base_, fd));
+ ++numPendingSSLConns_;
+ ++totalNumPendingSSLConns_;
+ if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
+ VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
+ " too many handshakes in progress";
+ updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
+ SSLErrorEnum::DROPPED);
+ sslConnectionError();
+ return;
+ }
+ new AcceptorHandshakeHelper(
+ std::move(sslSock), this, clientAddr, acceptTime);
+ } else {
+ TransportInfo tinfo;
+ tinfo.ssl = false;
+ tinfo.acceptTime = acceptTime;
+ AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
+ connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
+ }
+}
+
+void
+Acceptor::connectionReady(
+ AsyncSocket::UniquePtr sock,
+ const SocketAddress& clientAddr,
+ const string& nextProtocolName,
+ TransportInfo& tinfo) {
+ // Limit the number of reads from the socket per poll loop iteration,
+ // both to keep memory usage under control and to prevent one fast-
+ // writing client from starving other connections.
+ sock->setMaxReadsPerEvent(16);
+ tinfo.initWithSocket(sock.get());
+ onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
+}
+
+void
+Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
+ const SocketAddress& clientAddr,
+ const string& nextProtocol,
+ TransportInfo& tinfo) {
+ CHECK(numPendingSSLConns_ > 0);
+ connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
+ --numPendingSSLConns_;
+ --totalNumPendingSSLConns_;
+ if (state_ == State::kDraining) {
+ checkDrained();
+ }
+}
+
+void
+Acceptor::sslConnectionError() {
+ CHECK(numPendingSSLConns_ > 0);
+ --numPendingSSLConns_;
+ --totalNumPendingSSLConns_;
+ if (state_ == State::kDraining) {
+ checkDrained();
+ }
+}
+
+void
+Acceptor::acceptError(const std::exception& ex) noexcept {
+ // An error occurred.
+ // The most likely error is out of FDs. AsyncServerSocket will back off
+ // briefly if we are out of FDs, then continue accepting later.
+ // Just log a message here.
+ LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
+}
+
+void
+Acceptor::acceptStopped() noexcept {
+ VLOG(3) << "Acceptor " << this << " acceptStopped()";
+ // Drain the open client connections
+ drainAllConnections();
+
+ // If we haven't yet finished draining, begin doing so by marking ourselves
+ // as in the draining state. We must be sure to hit checkDrained() here, as
+ // if we're completely idle, we can should consider ourself drained
+ // immediately (as there is no outstanding work to complete to cause us to
+ // re-evaluate this).
+ if (state_ != State::kDone) {
+ state_ = State::kDraining;
+ checkDrained();
+ }
+}
+
+void
+Acceptor::onEmpty(const ConnectionManager& cm) {
+ VLOG(3) << "Acceptor=" << this << " onEmpty()";
+ if (state_ == State::kDraining) {
+ checkDrained();
+ }
+}
+
+void
+Acceptor::checkDrained() {
+ CHECK(state_ == State::kDraining);
+ if (forceShutdownInProgress_ ||
+ (downstreamConnectionManager_->getNumConnections() != 0) ||
+ (numPendingSSLConns_ != 0)) {
+ return;
+ }
+
+ VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
+ << base_;
+
+ downstreamConnectionManager_.reset();
+
+ state_ = State::kDone;
+
+ onConnectionsDrained();
+}
+
+milliseconds
+Acceptor::getConnTimeout() const {
+ return accConfig_.connectionIdleTimeout;
+}
+
+void Acceptor::addConnection(ManagedConnection* conn) {
+ // Add the socket to the timeout manager so that it can be cleaned
+ // up after being left idle for a long time.
+ downstreamConnectionManager_->addConnection(conn, true);
+}
+
+void
+Acceptor::forceStop() {
+ base_->runInEventBaseThread([&] { dropAllConnections(); });
+}
+
+void
+Acceptor::dropAllConnections() {
+ if (downstreamConnectionManager_) {
+ VLOG(3) << "Dropping all connections from Acceptor=" << this <<
+ " in thread " << base_;
+ assert(base_->isInEventBaseThread());
+ forceShutdownInProgress_ = true;
+ downstreamConnectionManager_->dropAllConnections();
+ CHECK(downstreamConnectionManager_->getNumConnections() == 0);
+ downstreamConnectionManager_.reset();
+ }
+ CHECK(numPendingSSLConns_ == 0);
+
+ state_ = State::kDone;
+ onConnectionsDrained();
+}
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include "folly/wangle/acceptor/ServerSocketConfig.h"
+#include "folly/wangle/acceptor/ConnectionCounter.h"
+#include <folly/wangle/acceptor/ConnectionManager.h>
+#include "folly/wangle/acceptor/LoadShedConfiguration.h"
+#include "folly/wangle/ssl/SSLCacheProvider.h"
+#include "folly/wangle/acceptor/TransportInfo.h"
+
+#include <chrono>
+#include <event.h>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
+
+namespace folly { namespace wangle {
+class ManagedConnection;
+}}
+
+namespace folly {
+
+class SocketAddress;
+class SSLContext;
+class AsyncTransport;
+class SSLContextManager;
+
+/**
+ * An abstract acceptor for TCP-based network services.
+ *
+ * There is one acceptor object per thread for each listening socket. When a
+ * new connection arrives on the listening socket, it is accepted by one of the
+ * acceptor objects. From that point on the connection will be processed by
+ * that acceptor's thread.
+ *
+ * The acceptor will call the abstract onNewConnection() method to create
+ * a new ManagedConnection object for each accepted socket. The acceptor
+ * also tracks all outstanding connections that it has accepted.
+ */
+class Acceptor :
+ public folly::AsyncServerSocket::AcceptCallback,
+ public folly::wangle::ConnectionManager::Callback {
+ public:
+
+ enum class State : uint32_t {
+ kInit, // not yet started
+ kRunning, // processing requests normally
+ kDraining, // processing outstanding conns, but not accepting new ones
+ kDone, // no longer accepting, and all connections finished
+ };
+
+ explicit Acceptor(const ServerSocketConfig& accConfig);
+ virtual ~Acceptor();
+
+ /**
+ * Supply an SSL cache provider
+ * @note Call this before init()
+ */
+ virtual void setSSLCacheProvider(
+ const std::shared_ptr<SSLCacheProvider>& cacheProvider) {
+ cacheProvider_ = cacheProvider;
+ }
+
+ /**
+ * Initialize the Acceptor to run in the specified EventBase
+ * thread, receiving connections from the specified AsyncServerSocket.
+ *
+ * This method will be called from the AsyncServerSocket's primary thread,
+ * not the specified EventBase thread.
+ */
+ virtual void init(AsyncServerSocket* serverSocket,
+ EventBase* eventBase);
+
+ /**
+ * Dynamically add a new SSLContextConfig
+ */
+ void addSSLContextConfig(const SSLContextConfig& sslCtxConfig);
+
+ SSLContextManager* getSSLContextManager() const {
+ return sslCtxManager_.get();
+ }
+
+ /**
+ * Return the number of outstanding connections in this service instance.
+ */
+ uint32_t getNumConnections() const {
+ return downstreamConnectionManager_ ?
+ downstreamConnectionManager_->getNumConnections() : 0;
+ }
+
+ /**
+ * Access the Acceptor's event base.
+ */
+ virtual EventBase* getEventBase() const { return base_; }
+
+ /**
+ * Access the Acceptor's downstream (client-side) ConnectionManager
+ */
+ virtual folly::wangle::ConnectionManager* getConnectionManager() {
+ return downstreamConnectionManager_.get();
+ }
+
+ /**
+ * Invoked when a new ManagedConnection is created.
+ *
+ * This allows the Acceptor to track the outstanding connections,
+ * for tracking timeouts and for ensuring that all connections have been
+ * drained on shutdown.
+ */
+ void addConnection(folly::wangle::ManagedConnection* connection);
+
+ /**
+ * Get this acceptor's current state.
+ */
+ State getState() const {
+ return state_;
+ }
+
+ /**
+ * Get the current connection timeout.
+ */
+ std::chrono::milliseconds getConnTimeout() const;
+
+ /**
+ * Returns the name of this VIP.
+ *
+ * Will return an empty string if no name has been configured.
+ */
+ const std::string& getName() const {
+ return accConfig_.name;
+ }
+
+ /**
+ * Force the acceptor to drop all connections and stop processing.
+ *
+ * This function may be called from any thread. The acceptor will not
+ * necessarily stop before this function returns: the stop will be scheduled
+ * to run in the acceptor's thread.
+ */
+ virtual void forceStop();
+
+ bool isSSL() const { return accConfig_.isSSL(); }
+
+ const ServerSocketConfig& getConfig() const { return accConfig_; }
+
+ static uint64_t getTotalNumPendingSSLConns() {
+ return totalNumPendingSSLConns_.load();
+ }
+
+ /**
+ * Called right when the TCP connection has been accepted, before processing
+ * the first HTTP bytes (HTTP) or the SSL handshake (HTTPS)
+ */
+ virtual void onDoneAcceptingConnection(
+ int fd,
+ const SocketAddress& clientAddr,
+ std::chrono::steady_clock::time_point acceptTime
+ ) noexcept;
+
+ /**
+ * Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS)
+ */
+ void processEstablishedConnection(
+ int fd,
+ const SocketAddress& clientAddr,
+ std::chrono::steady_clock::time_point acceptTime
+ ) noexcept;
+
+ /**
+ * Drains all open connections of their outstanding transactions. When
+ * a connection's transaction count reaches zero, the connection closes.
+ */
+ void drainAllConnections();
+
+ protected:
+ friend class AcceptorHandshakeHelper;
+
+ /**
+ * Our event loop.
+ *
+ * Probably needs to be used to pass to a ManagedConnection
+ * implementation. Also visible in case a subclass wishes to do additional
+ * things w/ the event loop (e.g. in attach()).
+ */
+ EventBase* base_{nullptr};
+
+ virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; }
+
+ /**
+ * Hook for subclasses to drop newly accepted connections prior
+ * to handshaking.
+ */
+ virtual bool canAccept(const folly::SocketAddress&);
+
+ /**
+ * Invoked when a new connection is created. This is where application starts
+ * processing a new downstream connection.
+ *
+ * NOTE: Application should add the new connection to
+ * downstreamConnectionManager so that it can be garbage collected after
+ * certain period of idleness.
+ *
+ * @param sock the socket connected to the client
+ * @param address the address of the client
+ * @param nextProtocolName the name of the L6 or L7 protocol to be
+ * spoken on the connection, if known (e.g.,
+ * from TLS NPN during secure connection setup),
+ * or an empty string if unknown
+ */
+ virtual void onNewConnection(
+ AsyncSocket::UniquePtr sock,
+ const folly::SocketAddress* address,
+ const std::string& nextProtocolName,
+ const TransportInfo& tinfo) = 0;
+
+ virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) {
+ return AsyncSocket::UniquePtr(new AsyncSocket(base, fd));
+ }
+
+ virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket(
+ const std::shared_ptr<SSLContext>& ctx, EventBase* base, int fd) {
+ return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd));
+ }
+
+ /**
+ * Hook for subclasses to record stats about SSL connection establishment.
+ */
+ virtual void updateSSLStats(
+ const AsyncSSLSocket* sock,
+ std::chrono::milliseconds acceptLatency,
+ SSLErrorEnum error) noexcept {}
+
+ /**
+ * Drop all connections.
+ *
+ * forceStop() schedules dropAllConnections() to be called in the acceptor's
+ * thread.
+ */
+ void dropAllConnections();
+
+ protected:
+
+ /**
+ * onConnectionsDrained() will be called once all connections have been
+ * drained while the acceptor is stopping.
+ *
+ * Subclasses can override this method to perform any subclass-specific
+ * cleanup.
+ */
+ virtual void onConnectionsDrained() {}
+
+ // AsyncServerSocket::AcceptCallback methods
+ void connectionAccepted(int fd,
+ const folly::SocketAddress& clientAddr)
+ noexcept;
+ void acceptError(const std::exception& ex) noexcept;
+ void acceptStopped() noexcept;
+
+ // ConnectionManager::Callback methods
+ void onEmpty(const folly::wangle::ConnectionManager& cm);
+ void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {}
+ void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {}
+
+ /**
+ * Process a connection that is to ready to receive L7 traffic.
+ * This method is called immediately upon accept for plaintext
+ * connections and upon completion of SSL handshaking or resumption
+ * for SSL connections.
+ */
+ void connectionReady(
+ AsyncSocket::UniquePtr sock,
+ const folly::SocketAddress& clientAddr,
+ const std::string& nextProtocolName,
+ TransportInfo& tinfo);
+
+ const LoadShedConfiguration& getLoadShedConfiguration() const {
+ return loadShedConfig_;
+ }
+
+ protected:
+ const ServerSocketConfig accConfig_;
+ void setLoadShedConfig(const LoadShedConfiguration& from,
+ IConnectionCounter* counter);
+
+ /**
+ * Socket options to apply to the client socket
+ */
+ AsyncSocket::OptionMap socketOptions_;
+
+ std::unique_ptr<SSLContextManager> sslCtxManager_;
+
+ /**
+ * Whether we want to enable client hello parsing in the handshake helper
+ * to get list of supported client ciphers.
+ */
+ bool parseClientHello_{false};
+
+ folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_;
+
+ private:
+
+ // Forbidden copy constructor and assignment opererator
+ Acceptor(Acceptor const &) = delete;
+ Acceptor& operator=(Acceptor const &) = delete;
+
+ /**
+ * Wrapper for connectionReady() that decrements the count of
+ * pending SSL connections.
+ */
+ void sslConnectionReady(AsyncSocket::UniquePtr sock,
+ const folly::SocketAddress& clientAddr,
+ const std::string& nextProtocol,
+ TransportInfo& tinfo);
+
+ /**
+ * Notification callback for SSL handshake failures.
+ */
+ void sslConnectionError();
+
+ void checkDrained();
+
+ State state_{State::kInit};
+ uint64_t numPendingSSLConns_{0};
+
+ static std::atomic<uint64_t> totalNumPendingSSLConns_;
+
+ bool forceShutdownInProgress_{false};
+ LoadShedConfiguration loadShedConfig_;
+ IConnectionCounter* connectionCounter_{nullptr};
+ std::shared_ptr<SSLCacheProvider> cacheProvider_;
+};
+
+class AcceptorFactory {
+ public:
+ virtual std::shared_ptr<Acceptor> newAcceptor() = 0;
+ virtual ~AcceptorFactory() = default;
+};
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+class IConnectionCounter {
+ public:
+ virtual uint64_t getNumConnections() const = 0;
+
+ /**
+ * Get the maximum number of non-whitelisted client-side connections
+ * across all Acceptors managed by this. A value
+ * of zero means "unlimited."
+ */
+ virtual uint64_t getMaxConnections() const = 0;
+
+ /**
+ * Increment the count of client-side connections.
+ */
+ virtual void onConnectionAdded() = 0;
+
+ /**
+ * Decrement the count of client-side connections.
+ */
+ virtual void onConnectionRemoved() = 0;
+ virtual ~IConnectionCounter() {}
+};
+
+class SimpleConnectionCounter: public IConnectionCounter {
+ public:
+ uint64_t getNumConnections() const override { return numConnections_; }
+ uint64_t getMaxConnections() const override { return maxConnections_; }
+ void setMaxConnections(uint64_t maxConnections) {
+ maxConnections_ = maxConnections;
+ }
+
+ void onConnectionAdded() override { numConnections_++; }
+ void onConnectionRemoved() override { numConnections_--; }
+ virtual ~SimpleConnectionCounter() {}
+
+ protected:
+ uint64_t maxConnections_{0};
+ uint64_t numConnections_{0};
+};
+
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/acceptor/ConnectionManager.h>
+
+#include <glog/logging.h>
+#include <folly/io/async/EventBase.h>
+
+using folly::HHWheelTimer;
+using std::chrono::milliseconds;
+
+namespace folly { namespace wangle {
+
+ConnectionManager::ConnectionManager(EventBase* eventBase,
+ milliseconds timeout, Callback* callback)
+ : connTimeouts_(new HHWheelTimer(eventBase)),
+ callback_(callback),
+ eventBase_(eventBase),
+ idleIterator_(conns_.end()),
+ idleLoopCallback_(this),
+ timeout_(timeout) {
+
+}
+
+void
+ConnectionManager::addConnection(ManagedConnection* connection,
+ bool timeout) {
+ CHECK_NOTNULL(connection);
+ ConnectionManager* oldMgr = connection->getConnectionManager();
+ if (oldMgr != this) {
+ if (oldMgr) {
+ // 'connection' was being previously managed in a different thread.
+ // We must remove it from that manager before adding it to this one.
+ oldMgr->removeConnection(connection);
+ }
+ conns_.push_back(*connection);
+ connection->setConnectionManager(this);
+ if (callback_) {
+ callback_->onConnectionAdded(*this);
+ }
+ }
+ if (timeout) {
+ scheduleTimeout(connection);
+ }
+}
+
+void
+ConnectionManager::scheduleTimeout(ManagedConnection* connection) {
+ if (timeout_ > std::chrono::milliseconds(0)) {
+ connTimeouts_->scheduleTimeout(connection, timeout_);
+ }
+}
+
+void ConnectionManager::scheduleTimeout(
+ folly::HHWheelTimer::Callback* callback,
+ std::chrono::milliseconds timeout) {
+ connTimeouts_->scheduleTimeout(callback, timeout);
+}
+
+void
+ConnectionManager::removeConnection(ManagedConnection* connection) {
+ if (connection->getConnectionManager() == this) {
+ connection->cancelTimeout();
+ connection->setConnectionManager(nullptr);
+
+ // Un-link the connection from our list, being careful to keep the iterator
+ // that we're using for idle shedding valid
+ auto it = conns_.iterator_to(*connection);
+ if (it == idleIterator_) {
+ ++idleIterator_;
+ }
+ conns_.erase(it);
+
+ if (callback_) {
+ callback_->onConnectionRemoved(*this);
+ if (getNumConnections() == 0) {
+ callback_->onEmpty(*this);
+ }
+ }
+ }
+}
+
+void
+ConnectionManager::initiateGracefulShutdown(
+ std::chrono::milliseconds idleGrace) {
+ if (idleGrace.count() > 0) {
+ idleLoopCallback_.scheduleTimeout(idleGrace);
+ VLOG(3) << "Scheduling idle grace period of " << idleGrace.count() << "ms";
+ } else {
+ action_ = ShutdownAction::DRAIN2;
+ VLOG(3) << "proceeding directly to closing idle connections";
+ }
+ drainAllConnections();
+}
+
+void
+ConnectionManager::drainAllConnections() {
+ DestructorGuard g(this);
+ size_t numCleared = 0;
+ size_t numKept = 0;
+
+ auto it = idleIterator_ == conns_.end() ?
+ conns_.begin() : idleIterator_;
+
+ while (it != conns_.end() && (numKept + numCleared) < 64) {
+ ManagedConnection& conn = *it++;
+ if (action_ == ShutdownAction::DRAIN1) {
+ conn.notifyPendingShutdown();
+ } else {
+ // Second time around: close idle sessions. If they aren't idle yet,
+ // have them close when they are idle
+ if (conn.isBusy()) {
+ numKept++;
+ } else {
+ numCleared++;
+ }
+ conn.closeWhenIdle();
+ }
+ }
+
+ if (action_ == ShutdownAction::DRAIN2) {
+ VLOG(2) << "Idle connections cleared: " << numCleared <<
+ ", busy conns kept: " << numKept;
+ }
+ if (it != conns_.end()) {
+ idleIterator_ = it;
+ eventBase_->runInLoop(&idleLoopCallback_);
+ } else {
+ action_ = ShutdownAction::DRAIN2;
+ }
+}
+
+void
+ConnectionManager::dropAllConnections() {
+ DestructorGuard g(this);
+
+ // Iterate through our connection list, and drop each connection.
+ VLOG(3) << "connections to drop: " << conns_.size();
+ idleLoopCallback_.cancelTimeout();
+ unsigned i = 0;
+ while (!conns_.empty()) {
+ ManagedConnection& conn = conns_.front();
+ conns_.pop_front();
+ conn.cancelTimeout();
+ conn.setConnectionManager(nullptr);
+ // For debugging purposes, dump information about the first few
+ // connections.
+ static const unsigned MAX_CONNS_TO_DUMP = 2;
+ if (++i <= MAX_CONNS_TO_DUMP) {
+ conn.dumpConnectionState(3);
+ }
+ conn.dropConnection();
+ }
+ idleIterator_ = conns_.end();
+ idleLoopCallback_.cancelLoopCallback();
+
+ if (callback_) {
+ callback_->onEmpty(*this);
+ }
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/acceptor/ManagedConnection.h>
+
+#include <chrono>
+#include <folly/Memory.h>
+#include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/HHWheelTimer.h>
+#include <folly/io/async/DelayedDestruction.h>
+#include <folly/io/async/EventBase.h>
+
+namespace folly { namespace wangle {
+
+/**
+ * A ConnectionManager keeps track of ManagedConnections.
+ */
+class ConnectionManager: public folly::DelayedDestruction {
+ public:
+
+ /**
+ * Interface for an optional observer that's notified about
+ * various events in a ConnectionManager
+ */
+ class Callback {
+ public:
+ virtual ~Callback() {}
+
+ /**
+ * Invoked when the number of connections managed by the
+ * ConnectionManager changes from nonzero to zero.
+ */
+ virtual void onEmpty(const ConnectionManager& cm) = 0;
+
+ /**
+ * Invoked when a connection is added to the ConnectionManager.
+ */
+ virtual void onConnectionAdded(const ConnectionManager& cm) = 0;
+
+ /**
+ * Invoked when a connection is removed from the ConnectionManager.
+ */
+ virtual void onConnectionRemoved(const ConnectionManager& cm) = 0;
+ };
+
+ typedef std::unique_ptr<ConnectionManager, Destructor> UniquePtr;
+
+ /**
+ * Returns a new instance of ConnectionManager wrapped in a unique_ptr
+ */
+ template<typename... Args>
+ static UniquePtr makeUnique(Args&&... args) {
+ return folly::make_unique<ConnectionManager, Destructor>(
+ std::forward<Args>(args)...);
+ }
+
+ /**
+ * Constructor not to be used by itself.
+ */
+ ConnectionManager(folly::EventBase* eventBase,
+ std::chrono::milliseconds timeout,
+ Callback* callback = nullptr);
+
+ /**
+ * Add a connection to the set of connections managed by this
+ * ConnectionManager.
+ *
+ * @param connection The connection to add.
+ * @param timeout Whether to immediately register this connection
+ * for an idle timeout callback.
+ */
+ void addConnection(ManagedConnection* connection,
+ bool timeout = false);
+
+ /**
+ * Schedule a timeout callback for a connection.
+ */
+ void scheduleTimeout(ManagedConnection* connection);
+
+ /*
+ * Schedule a callback on the wheel timer
+ */
+ void scheduleTimeout(folly::HHWheelTimer::Callback* callback,
+ std::chrono::milliseconds timeout);
+
+ /**
+ * Remove a connection from this ConnectionManager and, if
+ * applicable, cancel the pending timeout callback that the
+ * ConnectionManager has scheduled for the connection.
+ *
+ * @note This method does NOT destroy the connection.
+ */
+ void removeConnection(ManagedConnection* connection);
+
+ /* Begin gracefully shutting down connections in this ConnectionManager.
+ * Notify all connections of pending shutdown, and after idleGrace,
+ * begin closing idle connections.
+ */
+ void initiateGracefulShutdown(std::chrono::milliseconds idleGrace);
+
+ /**
+ * Destroy all connections Managed by this ConnectionManager, even
+ * the ones that are busy.
+ */
+ void dropAllConnections();
+
+ size_t getNumConnections() const { return conns_.size(); }
+
+ template <typename F>
+ void iterateConns(F func) {
+ auto it = conns_.begin();
+ while ( it != conns_.end()) {
+ func(&(*it));
+ it++;
+ }
+ }
+
+ private:
+ class CloseIdleConnsCallback :
+ public folly::EventBase::LoopCallback,
+ public folly::AsyncTimeout {
+ public:
+ explicit CloseIdleConnsCallback(ConnectionManager* manager)
+ : folly::AsyncTimeout(manager->eventBase_),
+ manager_(manager) {}
+
+ void runLoopCallback() noexcept override {
+ VLOG(3) << "Draining more conns from loop callback";
+ manager_->drainAllConnections();
+ }
+
+ void timeoutExpired() noexcept override {
+ VLOG(3) << "Idle grace expired";
+ manager_->drainAllConnections();
+ }
+
+ private:
+ ConnectionManager* manager_;
+ };
+
+ enum class ShutdownAction : uint8_t {
+ /**
+ * Drain part 1: inform remote that you will soon reject new requests.
+ */
+ DRAIN1 = 0,
+ /**
+ * Drain part 2: start rejecting new requests.
+ */
+ DRAIN2 = 1,
+ };
+
+ ~ConnectionManager() {}
+
+ ConnectionManager(const ConnectionManager&) = delete;
+ ConnectionManager& operator=(ConnectionManager&) = delete;
+
+ /**
+ * Destroy all connections managed by this ConnectionManager that
+ * are currently idle, as determined by a call to each ManagedConnection's
+ * isBusy() method.
+ */
+ void drainAllConnections();
+
+ /** All connections */
+ folly::CountedIntrusiveList<
+ ManagedConnection,&ManagedConnection::listHook_> conns_;
+
+ /** Connections that currently are registered for timeouts */
+ folly::HHWheelTimer::UniquePtr connTimeouts_;
+
+ /** Optional callback to notify of state changes */
+ Callback* callback_;
+
+ /** Event base in which we run */
+ folly::EventBase* eventBase_;
+
+ /** Iterator to the next connection to shed; used by drainAllConnections() */
+ folly::CountedIntrusiveList<
+ ManagedConnection,&ManagedConnection::listHook_>::iterator idleIterator_;
+ CloseIdleConnsCallback idleLoopCallback_;
+ ShutdownAction action_{ShutdownAction::DRAIN1};
+ std::chrono::milliseconds timeout_;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <string>
+
+namespace folly {
+
+struct dn_char_traits : public std::char_traits<char> {
+ static bool eq(char c1, char c2) {
+ return ::tolower(c1) == ::tolower(c2);
+ }
+
+ static bool ne(char c1, char c2) {
+ return ::tolower(c1) != ::tolower(c2);
+ }
+
+ static bool lt(char c1, char c2) {
+ return ::tolower(c1) < ::tolower(c2);
+ }
+
+ static int compare(const char* s1, const char* s2, size_t n) {
+ while (n--) {
+ if(::tolower(*s1) < ::tolower(*s2) ) {
+ return -1;
+ }
+ if(::tolower(*s1) > ::tolower(*s2) ) {
+ return 1;
+ }
+ ++s1;
+ ++s2;
+ }
+ return 0;
+ }
+
+ static const char* find(const char* s, size_t n, char a) {
+ char la = ::tolower(a);
+ while (n--) {
+ if(::tolower(*s) == la) {
+ return s;
+ } else {
+ ++s;
+ }
+ }
+ return nullptr;
+ }
+};
+
+// Case insensitive string
+typedef std::basic_string<char, dn_char_traits> DNString;
+
+struct DNStringHash : public std::hash<std::string> {
+ size_t operator()(const DNString& s) const noexcept {
+ size_t h = static_cast<size_t>(0xc70f6907UL);
+ const char* d = s.data();
+ for (size_t i = 0; i < s.length(); ++i) {
+ char a = ::tolower(*d++);
+ h = std::_Hash_impl::hash(&a, sizeof(a), h);
+ }
+ return h;
+ }
+};
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/acceptor/LoadShedConfiguration.h>
+
+#include <folly/Conv.h>
+#include <openssl/ssl.h>
+
+using std::string;
+
+namespace folly {
+
+void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) {
+ auto addr = input.str();
+ size_t separator = addr.find_first_of('/');
+ if (separator == string::npos) {
+ whitelistAddrs_.insert(SocketAddress(addr, 0));
+ } else {
+ unsigned prefixLen = folly::to<unsigned>(addr.substr(separator + 1));
+ addr.erase(separator);
+ whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen));
+ }
+}
+
+bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const {
+ if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) {
+ return true;
+ }
+ for (auto& network : whitelistNetworks_) {
+ if (network.contains(address)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <chrono>
+#include <folly/Range.h>
+#include <folly/SocketAddress.h>
+#include <glog/logging.h>
+#include <list>
+#include <set>
+#include <string>
+
+#include <folly/wangle/acceptor/NetworkAddress.h>
+
+namespace folly {
+
+/**
+ * Class that holds an LoadShed configuration for a service
+ */
+class LoadShedConfiguration {
+ public:
+
+ // Comparison function for SocketAddress that disregards the port
+ struct AddressOnlyCompare {
+ bool operator()(
+ const SocketAddress& addr1,
+ const SocketAddress& addr2) const {
+ return addr1.getIPAddress() < addr2.getIPAddress();
+ }
+ };
+
+ typedef std::set<SocketAddress, AddressOnlyCompare> AddressSet;
+ typedef std::set<NetworkAddress> NetworkSet;
+
+ LoadShedConfiguration() {}
+
+ ~LoadShedConfiguration() {}
+
+ void addWhitelistAddr(folly::StringPiece);
+
+ /**
+ * Set/get the set of IPs that should be whitelisted through even when we're
+ * trying to shed load.
+ */
+ void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; }
+ const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; }
+
+ /**
+ * Set/get the set of networks that should be whitelisted through even
+ * when we're trying to shed load.
+ */
+ void setWhitelistNetworks(const NetworkSet& networks) {
+ whitelistNetworks_ = networks;
+ }
+ const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; }
+
+ /**
+ * Set/get the maximum number of downstream connections across all VIPs.
+ */
+ void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; }
+ uint64_t getMaxConnections() const { return maxConnections_; }
+
+ /**
+ * Set/get the maximum cpu usage.
+ */
+ void setMaxMemUsage(double max) {
+ CHECK(max >= 0);
+ CHECK(max <= 1);
+ maxMemUsage_ = max;
+ }
+ double getMaxMemUsage() const { return maxMemUsage_; }
+
+ /**
+ * Set/get the maximum memory usage.
+ */
+ void setMaxCpuUsage(double max) {
+ CHECK(max >= 0);
+ CHECK(max <= 1);
+ maxCpuUsage_ = max;
+ }
+ double getMaxCpuUsage() const { return maxCpuUsage_; }
+
+ void setLoadUpdatePeriod(std::chrono::milliseconds period) {
+ period_ = period;
+ }
+ std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; }
+
+ bool isWhitelisted(const SocketAddress& addr) const;
+
+ private:
+
+ AddressSet whitelistAddrs_;
+ NetworkSet whitelistNetworks_;
+ uint64_t maxConnections_{0};
+ double maxMemUsage_;
+ double maxCpuUsage_;
+ std::chrono::milliseconds period_;
+};
+
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/acceptor/ManagedConnection.h>
+
+#include <folly/wangle/acceptor/ConnectionManager.h>
+
+namespace folly { namespace wangle {
+
+ManagedConnection::ManagedConnection()
+ : connectionManager_(nullptr) {
+}
+
+ManagedConnection::~ManagedConnection() {
+ if (connectionManager_) {
+ connectionManager_->removeConnection(this);
+ }
+}
+
+void
+ManagedConnection::resetTimeout() {
+ if (connectionManager_) {
+ connectionManager_->scheduleTimeout(this);
+ }
+}
+
+void
+ManagedConnection::scheduleTimeout(
+ folly::HHWheelTimer::Callback* callback,
+ std::chrono::milliseconds timeout) {
+ if (connectionManager_) {
+ connectionManager_->scheduleTimeout(callback, timeout);
+ }
+}
+
+////////////////////// Globals /////////////////////
+
+std::ostream&
+operator<<(std::ostream& os, const ManagedConnection& conn) {
+ conn.describe(os);
+ return os;
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/IntrusiveList.h>
+#include <ostream>
+#include <folly/io/async/HHWheelTimer.h>
+#include <folly/io/async/DelayedDestruction.h>
+
+namespace folly { namespace wangle {
+
+class ConnectionManager;
+
+/**
+ * Interface describing a connection that can be managed by a
+ * container such as an Acceptor.
+ */
+class ManagedConnection:
+ public folly::HHWheelTimer::Callback,
+ public folly::DelayedDestruction {
+ public:
+
+ ManagedConnection();
+
+ // HHWheelTimer::Callback API (left for subclasses to implement).
+ virtual void timeoutExpired() noexcept = 0;
+
+ /**
+ * Print a human-readable description of the connection.
+ * @param os Destination stream.
+ */
+ virtual void describe(std::ostream& os) const = 0;
+
+ /**
+ * Check whether the connection has any requests outstanding.
+ */
+ virtual bool isBusy() const = 0;
+
+ /**
+ * Notify the connection that a shutdown is pending. This method will be
+ * called at the beginning of graceful shutdown.
+ */
+ virtual void notifyPendingShutdown() = 0;
+
+ /**
+ * Instruct the connection that it should shutdown as soon as it is
+ * safe. This is called after notifyPendingShutdown().
+ */
+ virtual void closeWhenIdle() = 0;
+
+ /**
+ * Forcibly drop a connection.
+ *
+ * If a request is in progress, this should cause the connection to be
+ * closed with a reset.
+ */
+ virtual void dropConnection() = 0;
+
+ /**
+ * Dump the state of the connection to the log
+ */
+ virtual void dumpConnectionState(uint8_t loglevel) = 0;
+
+ /**
+ * If the connection has a connection manager, reset the timeout
+ * countdown.
+ * @note If the connection manager doesn't have the connection scheduled
+ * for a timeout already, this method will schedule one. If the
+ * connection manager does have the connection connection scheduled
+ * for a timeout, this method will push back the timeout to N msec
+ * from now, where N is the connection manager's timer interval.
+ */
+ virtual void resetTimeout();
+
+ // Schedule an arbitrary timeout on the HHWheelTimer
+ virtual void scheduleTimeout(
+ folly::HHWheelTimer::Callback* callback,
+ std::chrono::milliseconds timeout);
+
+ ConnectionManager* getConnectionManager() {
+ return connectionManager_;
+ }
+
+ protected:
+ virtual ~ManagedConnection();
+
+ private:
+ friend class ConnectionManager;
+
+ void setConnectionManager(ConnectionManager* mgr) {
+ connectionManager_ = mgr;
+ }
+
+ ConnectionManager* connectionManager_;
+
+ folly::SafeIntrusiveListHook listHook_;
+};
+
+std::ostream& operator<<(std::ostream& os, const ManagedConnection& conn);
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/SocketAddress.h>
+
+namespace folly {
+
+/**
+ * A simple wrapper around SocketAddress that represents
+ * a network in CIDR notation
+ */
+class NetworkAddress {
+public:
+ /**
+ * Create a NetworkAddress for an addr/prefixLen
+ * @param addr IPv4 or IPv6 address of the network
+ * @param prefixLen Prefix length, in bits
+ */
+ NetworkAddress(const folly::SocketAddress& addr,
+ unsigned prefixLen):
+ addr_(addr), prefixLen_(prefixLen) {}
+
+ /** Get the network address */
+ const folly::SocketAddress& getAddress() const {
+ return addr_;
+ }
+
+ /** Get the prefix length in bits */
+ unsigned getPrefixLength() const { return prefixLen_; }
+
+ /** Check whether a given address lies within the network */
+ bool contains(const folly::SocketAddress& addr) const {
+ return addr_.prefixMatch(addr, prefixLen_);
+ }
+
+ /** Comparison operator to enable use in ordered collections */
+ bool operator<(const NetworkAddress& other) const {
+ if (addr_ < other.addr_) {
+ return true;
+ } else if (other.addr_ < addr_) {
+ return false;
+ } else {
+ return (prefixLen_ < other.prefixLen_);
+ }
+ }
+
+private:
+ folly::SocketAddress addr_;
+ unsigned prefixLen_;
+};
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/wangle/ssl/SSLCacheOptions.h>
+#include <folly/wangle/ssl/SSLContextConfig.h>
+#include <folly/wangle/ssl/TLSTicketKeySeeds.h>
+#include <folly/wangle/ssl/SSLUtil.h>
+#include <folly/wangle/acceptor/SocketOptions.h>
+
+#include <boost/optional.hpp>
+#include <chrono>
+#include <fcntl.h>
+#include <folly/Random.h>
+#include <folly/SocketAddress.h>
+#include <folly/String.h>
+#include <folly/io/async/SSLContext.h>
+#include <list>
+#include <string>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/SSLContext.h>
+#include <folly/SocketAddress.h>
+
+namespace folly {
+
+/**
+ * Configuration for a single Acceptor.
+ *
+ * This configures not only accept behavior, but also some types of SSL
+ * behavior that may make sense to configure on a per-VIP basis (e.g. which
+ * cert(s) we use, etc).
+ */
+struct ServerSocketConfig {
+ ServerSocketConfig() {
+ // generate a single random current seed
+ uint8_t seed[32];
+ folly::Random::secureRandom(seed, sizeof(seed));
+ initialTicketSeeds.currentSeeds.push_back(
+ SSLUtil::hexlify(std::string((char *)seed, sizeof(seed))));
+ }
+
+ bool isSSL() const { return !(sslContextConfigs.empty()); }
+
+ /**
+ * Set/get the socket options to apply on all downstream connections.
+ */
+ void setSocketOptions(
+ const AsyncSocket::OptionMap& opts) {
+ socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily());
+ }
+ AsyncSocket::OptionMap&
+ getSocketOptions() {
+ return socketOptions_;
+ }
+ const AsyncSocket::OptionMap&
+ getSocketOptions() const {
+ return socketOptions_;
+ }
+
+ bool hasExternalPrivateKey() const {
+ for (const auto& cfg : sslContextConfigs) {
+ if (!cfg.isLocalPrivateKey) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * The name of this acceptor; used for stats/reporting purposes.
+ */
+ std::string name;
+
+ /**
+ * The depth of the accept queue backlog.
+ */
+ uint32_t acceptBacklog{1024};
+
+ /**
+ * The number of milliseconds a connection can be idle before we close it.
+ */
+ std::chrono::milliseconds connectionIdleTimeout{600000};
+
+ /**
+ * The address to bind to.
+ */
+ SocketAddress bindAddress;
+
+ /**
+ * Options for controlling the SSL cache.
+ */
+ SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200};
+
+ /**
+ * The initial TLS ticket seeds.
+ */
+ TLSTicketKeySeeds initialTicketSeeds;
+
+ /**
+ * The configs for all the SSL_CTX for use by this Acceptor.
+ */
+ std::vector<SSLContextConfig> sslContextConfigs;
+
+ /**
+ * Determines if the Acceptor does strict checking when loading the SSL
+ * contexts.
+ */
+ bool strictSSL{true};
+
+ /**
+ * Maximum number of concurrent pending SSL handshakes
+ */
+ uint32_t maxConcurrentSSLHandshakes{30720};
+
+ private:
+ AsyncSocket::OptionMap socketOptions_;
+};
+
+} // folly
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/acceptor/SocketOptions.h>
+
+#include <netinet/tcp.h>
+#include <sys/socket.h>
+
+namespace folly {
+
+AsyncSocket::OptionMap filterIPSocketOptions(
+ const AsyncSocket::OptionMap& allOptions,
+ const int addrFamily) {
+ AsyncSocket::OptionMap opts;
+ int exclude;
+ if (addrFamily == AF_INET) {
+ exclude = IPPROTO_IPV6;
+ } else if (addrFamily == AF_INET6) {
+ exclude = IPPROTO_IP;
+ } else {
+ LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6";
+ return opts;
+ }
+ for (const auto& opt: allOptions) {
+ if (opt.first.level != exclude) {
+ opts[opt.first] = opt.second;
+ }
+ }
+ return opts;
+}
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/AsyncSocket.h>
+
+namespace folly {
+
+/**
+ * Returns a copy of the socket options excluding options with the given
+ * level.
+ */
+AsyncSocket::OptionMap filterIPSocketOptions(
+ const AsyncSocket::OptionMap& allOptions,
+ const int addrFamily);
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/acceptor/TransportInfo.h>
+
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <folly/io/async/AsyncSocket.h>
+
+using std::chrono::microseconds;
+using std::map;
+using std::string;
+
+namespace folly {
+
+bool TransportInfo::initWithSocket(const AsyncSocket* sock) {
+#if defined(__linux__) || defined(__FreeBSD__)
+ if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
+ tcpinfoErrno = errno;
+ return false;
+ }
+ rtt = microseconds(tcpinfo.tcpi_rtt);
+ validTcpinfo = true;
+#else
+ tcpinfoErrno = EINVAL;
+ rtt = microseconds(-1);
+#endif
+ return true;
+}
+
+int64_t TransportInfo::readRTT(const AsyncSocket* sock) {
+#if defined(__linux__) || defined(__FreeBSD__)
+ struct tcp_info tcpinfo;
+ if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
+ return -1;
+ }
+ return tcpinfo.tcpi_rtt;
+#else
+ return -1;
+#endif
+}
+
+#if defined(__linux__) || defined(__FreeBSD__)
+bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo,
+ const AsyncSocket* sock) {
+ socklen_t len = sizeof(struct tcp_info);
+ if (!sock) {
+ return false;
+ }
+ if (getsockopt(sock->getFd(), IPPROTO_TCP,
+ TCP_INFO, (void*) tcpinfo, &len) < 0) {
+ VLOG(4) << "Error calling getsockopt(): " << strerror(errno);
+ return false;
+ }
+ return true;
+}
+#endif
+
+} // folly
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/wangle/ssl/SSLUtil.h>
+
+#include <chrono>
+#include <netinet/tcp.h>
+#include <string>
+
+namespace folly {
+class AsyncSocket;
+
+/**
+ * A structure that encapsulates byte counters related to the HTTP headers.
+ */
+struct HTTPHeaderSize {
+ /**
+ * The number of bytes used to represent the header after compression or
+ * before decompression. If header compression is not supported, the value
+ * is set to 0.
+ */
+ size_t compressed{0};
+
+ /**
+ * The number of bytes used to represent the serialized header before
+ * compression or after decompression, in plain-text format.
+ */
+ size_t uncompressed{0};
+};
+
+struct TransportInfo {
+ /*
+ * timestamp of when the connection handshake was completed
+ */
+ std::chrono::steady_clock::time_point acceptTime{};
+
+ /*
+ * connection RTT (Round-Trip Time)
+ */
+ std::chrono::microseconds rtt{0};
+
+#if defined(__linux__) || defined(__FreeBSD__)
+ /*
+ * TCP information as fetched from getsockopt(2)
+ */
+ tcp_info tcpinfo {
+#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
+ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32
+#else
+ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29
+#endif // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
+ };
+#endif // defined(__linux__) || defined(__FreeBSD__)
+
+ /*
+ * time for setting the connection, from the moment in was accepted until it
+ * is established.
+ */
+ std::chrono::milliseconds setupTime{0};
+
+ /*
+ * time for setting up the SSL connection or SSL handshake
+ */
+ std::chrono::milliseconds sslSetupTime{0};
+
+ /*
+ * The name of the SSL ciphersuite used by the transaction's
+ * transport. Returns null if the transport is not SSL.
+ */
+ const char* sslCipher{nullptr};
+
+ /*
+ * The SSL server name used by the transaction's
+ * transport. Returns null if the transport is not SSL.
+ */
+ const char* sslServerName{nullptr};
+
+ /*
+ * list of ciphers sent by the client
+ */
+ std::string sslClientCiphers{};
+
+ /*
+ * list of compression methods sent by the client
+ */
+ std::string sslClientComprMethods{};
+
+ /*
+ * list of TLS extensions sent by the client
+ */
+ std::string sslClientExts{};
+
+ /*
+ * hash of all the SSL parameters sent by the client
+ */
+ std::string sslSignature{};
+
+ /*
+ * list of ciphers supported by the server
+ */
+ std::string sslServerCiphers{};
+
+ /*
+ * guessed "(os) (browser)" based on SSL Signature
+ */
+ std::string guessedUserAgent{};
+
+ /**
+ * The result of SSL NPN negotiation.
+ */
+ std::string sslNextProtocol{};
+
+ /*
+ * total number of bytes sent over the connection
+ */
+ int64_t totalBytes{0};
+
+ /**
+ * header bytes read
+ */
+ HTTPHeaderSize ingressHeader;
+
+ /*
+ * header bytes written
+ */
+ HTTPHeaderSize egressHeader;
+
+ /*
+ * Here is how the timeToXXXByte variables are planned out:
+ * 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_
+ * 2. You can get the timing between two ByteEvents by calculating their
+ * differences. For example:
+ * timeToLastBodyByteAck - timeToFirstByte
+ * => Total time to deliver the body
+ * 3. The calculation in point (2) is typically done outside acceptor
+ *
+ * Future plan:
+ * We should log the timestamps (TimePoints) and allow
+ * the consumer to calculate the latency whatever it
+ * wants instead of calculating them in wangle, for the sake of flexibility.
+ * For example:
+ * 1. TimePoint reqStartTimestamp;
+ * 2. TimePoint firstHeaderByteSentTimestamp;
+ * 3. TimePoint firstBodyByteTimestamp;
+ * 3. TimePoint lastBodyByteTimestamp;
+ * 4. TimePoint lastBodyByteAckTimestamp;
+ */
+
+ /*
+ * time to first header byte written to the kernel send buffer
+ * NOTE: It is not 100% accurate since TAsyncSocket does not do
+ * do callback on partial write.
+ */
+ int32_t timeToFirstHeaderByte{-1};
+
+ /*
+ * time to first body byte written to the kernel send buffer
+ */
+ int32_t timeToFirstByte{-1};
+
+ /*
+ * time to last body byte written to the kernel send buffer
+ */
+ int32_t timeToLastByte{-1};
+
+ /*
+ * time to TCP Ack received for the last written body byte
+ */
+ int32_t timeToLastBodyByteAck{-1};
+
+ /*
+ * time it took the client to ACK the last byte, from the moment when the
+ * kernel sent the last byte to the client and until it received the ACK
+ * for that byte
+ */
+ int32_t lastByteAckLatency{-1};
+
+ /*
+ * time spent inside wangle
+ */
+ int32_t proxyLatency{-1};
+
+ /*
+ * time between connection accepted and client message headers completed
+ */
+ int32_t clientLatency{-1};
+
+ /*
+ * latency for communication with the server
+ */
+ int32_t serverLatency{-1};
+
+ /*
+ * time used to get a usable connection.
+ */
+ int32_t connectLatency{-1};
+
+ /*
+ * body bytes written
+ */
+ uint32_t egressBodySize{0};
+
+ /*
+ * value of errno in case of getsockopt() error
+ */
+ int tcpinfoErrno{0};
+
+ /*
+ * bytes read & written during SSL Setup
+ */
+ uint32_t sslSetupBytesWritten{0};
+ uint32_t sslSetupBytesRead{0};
+
+ /**
+ * SSL error detail
+ */
+ uint32_t sslError{0};
+
+ /**
+ * body bytes read
+ */
+ uint32_t ingressBodySize{0};
+
+ /*
+ * The SSL version used by the transaction's transport, in
+ * OpenSSL's format: 4 bits for the major version, followed by 4 bits
+ * for the minor version. Returns zero for non-SSL.
+ */
+ uint16_t sslVersion{0};
+
+ /*
+ * The SSL certificate size.
+ */
+ uint16_t sslCertSize{0};
+
+ /**
+ * response status code
+ */
+ uint16_t statusCode{0};
+
+ /*
+ * The SSL mode for the transaction's transport: new session,
+ * resumed session, or neither (non-SSL).
+ */
+ SSLResumeEnum sslResume{SSLResumeEnum::NA};
+
+ /*
+ * true if the tcpinfo was successfully read from the kernel
+ */
+ bool validTcpinfo{false};
+
+ /*
+ * true if the connection is SSL, false otherwise
+ */
+ bool ssl{false};
+
+ /*
+ * get the RTT value in milliseconds
+ */
+ std::chrono::milliseconds getRttMs() const {
+ return std::chrono::duration_cast<std::chrono::milliseconds>(rtt);
+ }
+
+ /*
+ * initialize the fields related with tcp_info
+ */
+ bool initWithSocket(const AsyncSocket* sock);
+
+ /*
+ * Get the kernel's estimate of round-trip time (RTT) to the transport's peer
+ * in microseconds. Returns -1 on error.
+ */
+ static int64_t readRTT(const AsyncSocket* sock);
+
+#if defined(__linux__) || defined(__FreeBSD__)
+ /*
+ * perform the getsockopt(2) syscall to fetch TCP info for a given socket
+ */
+ static bool readTcpInfo(struct tcp_info* tcpinfo,
+ const AsyncSocket* sock);
+#endif
+};
+
+} // folly
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "folly/wangle/bootstrap/ServerBootstrap.h"
+#include "folly/wangle/bootstrap/ClientBootstrap.h"
+#include "folly/wangle/channel/ChannelHandler.h"
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+using namespace folly::wangle;
+using namespace folly;
+
+typedef ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>> Pipeline;
+
+class TestServer : public ServerBootstrap<Pipeline> {
+ Pipeline* newPipeline(std::shared_ptr<AsyncSocket>) {
+ return nullptr;
+ }
+};
+
+class TestClient : public ClientBootstrap<Pipeline> {
+ Pipeline* newPipeline(std::shared_ptr<AsyncSocket> sock) {
+ CHECK(sock->good());
+
+ // We probably aren't connected immedately, check after a small delay
+ EventBaseManager::get()->getEventBase()->runAfterDelay([sock](){
+ CHECK(sock->readable());
+ }, 100);
+ return nullptr;
+ }
+};
+
+class TestPipelineFactory : public PipelineFactory<Pipeline> {
+ public:
+ Pipeline* newPipeline(std::shared_ptr<AsyncSocket> sock) {
+ pipelines++;
+ return new Pipeline();
+ }
+ std::atomic<int> pipelines{0};
+};
+
+TEST(Bootstrap, Basic) {
+ TestServer server;
+ TestClient client;
+}
+
+TEST(Bootstrap, ServerWithPipeline) {
+ TestServer server;
+ server.childPipeline(std::make_shared<TestPipelineFactory>());
+ server.bind(0);
+ server.stop();
+}
+
+TEST(Bootstrap, ClientServerTest) {
+ TestServer server;
+ auto factory = std::make_shared<TestPipelineFactory>();
+ server.childPipeline(factory);
+ server.bind(0);
+ auto base = EventBaseManager::get()->getEventBase();
+
+ SocketAddress address;
+ server.getSockets()[0]->getAddress(&address);
+
+ TestClient client;
+ client.connect(address);
+ base->loop();
+ server.stop();
+
+ CHECK(factory->pipelines == 1);
+}
+
+TEST(Bootstrap, ClientConnectionManagerTest) {
+ // Create a single IO thread, and verify that
+ // client connections are pooled properly
+
+ TestServer server;
+ auto factory = std::make_shared<TestPipelineFactory>();
+ server.childPipeline(factory);
+ server.group(std::make_shared<IOThreadPoolExecutor>(1));
+ server.bind(0);
+ auto base = EventBaseManager::get()->getEventBase();
+
+ SocketAddress address;
+ server.getSockets()[0]->getAddress(&address);
+
+ TestClient client;
+ client.connect(address);
+
+ TestClient client2;
+ client2.connect(address);
+
+ base->loop();
+ server.stop();
+
+ CHECK(factory->pipelines == 2);
+}
+
+TEST(Bootstrap, ServerAcceptGroupTest) {
+ // Verify that server is using the accept IO group
+
+ TestServer server;
+ auto factory = std::make_shared<TestPipelineFactory>();
+ server.childPipeline(factory);
+ server.group(std::make_shared<IOThreadPoolExecutor>(1), nullptr);
+ server.bind(0);
+
+ SocketAddress address;
+ server.getSockets()[0]->getAddress(&address);
+
+ boost::barrier barrier(2);
+ auto thread = std::thread([&](){
+ TestClient client;
+ client.connect(address);
+ EventBaseManager::get()->getEventBase()->loop();
+ barrier.wait();
+ });
+ barrier.wait();
+ server.stop();
+ thread.join();
+
+ CHECK(factory->pipelines == 1);
+}
+
+TEST(Bootstrap, ServerAcceptGroup2Test) {
+ // Verify that server is using the accept IO group
+
+ // Check if reuse port is supported, if not, don't run this test
+ try {
+ EventBase base;
+ auto serverSocket = AsyncServerSocket::newSocket(&base);
+ serverSocket->bind(0);
+ serverSocket->listen(0);
+ serverSocket->startAccepting();
+ serverSocket->setReusePortEnabled(true);
+ serverSocket->stopAccepting();
+ } catch(...) {
+ LOG(INFO) << "Reuse port probably not supported";
+ return;
+ }
+
+ TestServer server;
+ auto factory = std::make_shared<TestPipelineFactory>();
+ server.childPipeline(factory);
+ server.group(std::make_shared<IOThreadPoolExecutor>(4), nullptr);
+ server.bind(0);
+
+ SocketAddress address;
+ server.getSockets()[0]->getAddress(&address);
+
+ TestClient client;
+ client.connect(address);
+ EventBaseManager::get()->getEventBase()->loop();
+
+ server.stop();
+
+ CHECK(factory->pipelines == 1);
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/wangle/channel/ChannelPipeline.h>
+
+namespace folly {
+
+/*
+ * A thin wrapper around ChannelPipeline and AsyncSocket to match
+ * ServerBootstrap. On connect() a new pipeline is created.
+ */
+template <typename Pipeline>
+class ClientBootstrap {
+ public:
+ ClientBootstrap() {
+ }
+ ClientBootstrap* bind(int port) {
+ port_ = port;
+ return this;
+ }
+ ClientBootstrap* connect(SocketAddress address) {
+ pipeline_.reset(
+ newPipeline(
+ AsyncSocket::newSocket(EventBaseManager::get()->getEventBase(), address)
+ ));
+ return this;
+ }
+
+ virtual ~ClientBootstrap() {}
+
+ protected:
+ std::unique_ptr<Pipeline,
+ folly::DelayedDestruction::Destructor> pipeline_;
+
+ int port_;
+
+ virtual Pipeline* newPipeline(std::shared_ptr<AsyncSocket> socket) = 0;
+};
+
+} // namespace
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/wangle/acceptor/Acceptor.h>
+#include <folly/io/async/EventBaseManager.h>
+#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
+#include <folly/wangle/acceptor/ManagedConnection.h>
+#include <folly/wangle/channel/ChannelPipeline.h>
+
+namespace folly {
+
+template <typename Pipeline>
+class ServerAcceptor : public Acceptor {
+ typedef std::unique_ptr<Pipeline,
+ folly::DelayedDestruction::Destructor> PipelinePtr;
+
+ class ServerConnection : public wangle::ManagedConnection {
+ public:
+ explicit ServerConnection(PipelinePtr pipeline)
+ : pipeline_(std::move(pipeline)) {}
+
+ ~ServerConnection() {
+ }
+
+ void timeoutExpired() noexcept {
+ }
+
+ void describe(std::ostream& os) const {}
+ bool isBusy() const {
+ return false;
+ }
+ void notifyPendingShutdown() {}
+ void closeWhenIdle() {}
+ void dropConnection() {}
+ void dumpConnectionState(uint8_t loglevel) {}
+ private:
+ PipelinePtr pipeline_;
+ };
+
+ public:
+ explicit ServerAcceptor(
+ std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory)
+ : Acceptor(ServerSocketConfig())
+ , pipelineFactory_(pipelineFactory) {
+ Acceptor::init(nullptr, &base_);
+ }
+
+ /* See Acceptor::onNewConnection for details */
+ void onNewConnection(
+ AsyncSocket::UniquePtr transport, const SocketAddress* address,
+ const std::string& nextProtocolName, const TransportInfo& tinfo) {
+
+ std::unique_ptr<Pipeline,
+ folly::DelayedDestruction::Destructor>
+ pipeline(pipelineFactory_->newPipeline(
+ std::shared_ptr<AsyncSocket>(
+ transport.release(),
+ folly::DelayedDestruction::Destructor())));
+ auto connection = new ServerConnection(std::move(pipeline));
+ Acceptor::addConnection(connection);
+ }
+
+ ~ServerAcceptor() {
+ Acceptor::dropAllConnections();
+ }
+
+ private:
+ EventBase base_;
+
+ std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_;
+};
+
+template <typename Pipeline>
+class ServerAcceptorFactory : public AcceptorFactory {
+ public:
+ explicit ServerAcceptorFactory(
+ std::shared_ptr<PipelineFactory<Pipeline>> factory)
+ : factory_(factory) {}
+
+ std::shared_ptr<Acceptor> newAcceptor() {
+ return std::make_shared<ServerAcceptor<Pipeline>>(factory_);
+ }
+ private:
+ std::shared_ptr<PipelineFactory<Pipeline>> factory_;
+};
+
+class ServerWorkerFactory : public folly::wangle::ThreadFactory {
+ public:
+ explicit ServerWorkerFactory(std::shared_ptr<AcceptorFactory> acceptorFactory)
+ : internalFactory_(
+ std::make_shared<folly::wangle::NamedThreadFactory>("BootstrapWorker"))
+ , acceptorFactory_(acceptorFactory)
+ {}
+ virtual std::thread newThread(folly::Func&& func) override;
+
+ void setInternalFactory(
+ std::shared_ptr<folly::wangle::NamedThreadFactory> internalFactory);
+ void setNamePrefix(folly::StringPiece prefix);
+
+ template <typename F>
+ void forEachWorker(F&& f);
+
+ private:
+ std::shared_ptr<folly::wangle::NamedThreadFactory> internalFactory_;
+ folly::RWSpinLock workersLock_;
+ std::map<int32_t, std::shared_ptr<Acceptor>> workers_;
+ int32_t nextWorkerId_{0};
+
+ std::shared_ptr<AcceptorFactory> acceptorFactory_;
+};
+
+template <typename F>
+void ServerWorkerFactory::forEachWorker(F&& f) {
+ folly::RWSpinLock::ReadHolder guard(workersLock_);
+ for (const auto& kv : workers_) {
+ f(kv.second.get());
+ }
+}
+
+} // namespace
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/wangle/bootstrap/ServerBootstrap.h>
+#include <folly/wangle/concurrent/NamedThreadFactory.h>
+#include <folly/io/async/EventBaseManager.h>
+
+namespace folly {
+
+std::thread ServerWorkerFactory::newThread(
+ folly::Func&& func) {
+ auto id = nextWorkerId_++;
+ auto worker = acceptorFactory_->newAcceptor();
+ {
+ folly::RWSpinLock::WriteHolder guard(workersLock_);
+ workers_.insert({id, worker});
+ }
+ return internalFactory_->newThread([=](){
+ EventBaseManager::get()->setEventBase(worker->getEventBase(), false);
+ func();
+ EventBaseManager::get()->clearEventBase();
+
+ worker->drainAllConnections();
+ {
+ folly::RWSpinLock::WriteHolder guard(workersLock_);
+ workers_.erase(id);
+ }
+ });
+}
+
+void ServerWorkerFactory::setInternalFactory(
+ std::shared_ptr<wangle::NamedThreadFactory> internalFactory) {
+ CHECK(workers_.empty());
+ internalFactory_ = internalFactory;
+}
+
+void ServerWorkerFactory::setNamePrefix(folly::StringPiece prefix) {
+ CHECK(workers_.empty());
+ internalFactory_->setNamePrefix(prefix);
+}
+
+} // namespace
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/wangle/bootstrap/ServerBootstrap-inl.h>
+#include <boost/thread.hpp>
+
+namespace folly {
+
+/*
+ * ServerBootstrap is a parent class intended to set up a
+ * high-performance TCP accepting server. It will manage a pool of
+ * accepting threads, any number of accepting sockets, a pool of
+ * IO-worker threads, and connection pool for each IO thread for you.
+ *
+ * The output is given as a ChannelPipeline template: given a
+ * PipelineFactory, it will create a new pipeline for each connection,
+ * and your server can handle the incoming bytes.
+ *
+ * BACKWARDS COMPATIBLITY: for servers already taking a pool of
+ * Acceptor objects, an AcceptorFactory can be given directly instead
+ * of a pipeline factory.
+ */
+template <typename Pipeline>
+class ServerBootstrap {
+ public:
+ /* TODO(davejwatson)
+ *
+ * If there is any work to be done BEFORE handing the work to IO
+ * threads, this handler is where the pipeline to do it would be
+ * set.
+ *
+ * This could be used for things like logging, load balancing, or
+ * advanced load balancing on IO threads. Netty also provides this.
+ */
+ ServerBootstrap* handler() {
+ return this;
+ }
+
+ /*
+ * BACKWARDS COMPATIBILITY - an acceptor factory can be set. Your
+ * Acceptor is responsible for managing the connection pool.
+ *
+ * @param childHandler - acceptor factory to call for each IO thread
+ */
+ ServerBootstrap* childHandler(std::shared_ptr<AcceptorFactory> childHandler) {
+ acceptorFactory_ = childHandler;
+ return this;
+ }
+
+ /*
+ * Set a pipeline factory that will be called for each new connection
+ *
+ * @param factory pipeline factory to use for each new connection
+ */
+ ServerBootstrap* childPipeline(
+ std::shared_ptr<PipelineFactory<Pipeline>> factory) {
+ pipelineFactory_ = factory;
+ return this;
+ }
+
+ /*
+ * Set the IO executor. If not set, a default one will be created
+ * with one thread per core.
+ *
+ * @param io_group - io executor to use for IO threads.
+ */
+ ServerBootstrap* group(
+ std::shared_ptr<folly::wangle::IOThreadPoolExecutor> io_group) {
+ return group(nullptr, io_group);
+ }
+
+ /*
+ * Set the acceptor executor, and IO executor.
+ *
+ * If no acceptor executor is set, a single thread will be created for accepts
+ * If no IO executor is set, a default of one thread per core will be created
+ *
+ * @param group - acceptor executor to use for acceptor threads.
+ * @param io_group - io executor to use for IO threads.
+ */
+ ServerBootstrap* group(
+ std::shared_ptr<folly::wangle::IOThreadPoolExecutor> accept_group,
+ std::shared_ptr<wangle::IOThreadPoolExecutor> io_group) {
+ if (!accept_group) {
+ accept_group = std::make_shared<folly::wangle::IOThreadPoolExecutor>(
+ 1, std::make_shared<wangle::NamedThreadFactory>("Acceptor Thread"));
+ }
+ if (!io_group) {
+ io_group = std::make_shared<folly::wangle::IOThreadPoolExecutor>(
+ 32, std::make_shared<wangle::NamedThreadFactory>("IO Thread"));
+ }
+ auto factoryBase = io_group->getThreadFactory();
+ CHECK(factoryBase);
+ auto factory = std::dynamic_pointer_cast<folly::wangle::NamedThreadFactory>(
+ factoryBase);
+ CHECK(factory); // Must be named thread factory
+
+ CHECK(acceptorFactory_ || pipelineFactory_);
+
+ if (acceptorFactory_) {
+ workerFactory_ = std::make_shared<ServerWorkerFactory>(
+ acceptorFactory_);
+ } else {
+ workerFactory_ = std::make_shared<ServerWorkerFactory>(
+ std::make_shared<ServerAcceptorFactory<Pipeline>>(pipelineFactory_));
+ }
+ workerFactory_->setInternalFactory(factory);
+
+ acceptor_group_ = accept_group;
+ io_group_ = io_group;
+
+ auto numThreads = io_group_->numThreads();
+ io_group_->setNumThreads(0);
+ io_group_->setThreadFactory(workerFactory_);
+ io_group_->setNumThreads(numThreads);
+
+ return this;
+ }
+
+ /*
+ * Bind to a port and start listening.
+ * One of childPipeline or childHandler must be called before bind
+ *
+ * @param port Port to listen on
+ */
+ void bind(int port) {
+ // TODO take existing socket
+
+ if (!workerFactory_) {
+ group(nullptr);
+ }
+
+ bool reusePort = false;
+ if (acceptor_group_->numThreads() > 1) {
+ reusePort = true;
+ }
+
+ std::mutex sock_lock;
+ std::vector<std::shared_ptr<folly::AsyncServerSocket>> new_sockets;
+
+ auto startupFunc = [&](std::shared_ptr<boost::barrier> barrier){
+ auto socket = folly::AsyncServerSocket::newSocket();
+ sock_lock.lock();
+ new_sockets.push_back(socket);
+ sock_lock.unlock();
+ socket->setReusePortEnabled(reusePort);
+ socket->attachEventBase(EventBaseManager::get()->getEventBase());
+ socket->bind(port);
+ // TODO Take ServerSocketConfig
+ socket->listen(1024);
+ socket->startAccepting();
+
+ if (port == 0) {
+ SocketAddress address;
+ socket->getAddress(&address);
+ port = address.getPort();
+ }
+
+ barrier->wait();
+ };
+
+ auto bind0 = std::make_shared<boost::barrier>(2);
+ acceptor_group_->add(std::bind(startupFunc, bind0));
+ bind0->wait();
+
+ auto barrier = std::make_shared<boost::barrier>(acceptor_group_->numThreads());
+ for (int i = 1; i < acceptor_group_->numThreads(); i++) {
+ acceptor_group_->add(std::bind(startupFunc, barrier));
+ }
+ barrier->wait();
+
+ // Startup all the threads
+ for(auto socket : new_sockets) {
+ workerFactory_->forEachWorker([this, socket](Acceptor* worker){
+ socket->getEventBase()->runInEventBaseThread([this, worker, socket](){
+ socket->addAcceptCallback(worker, worker->getEventBase());
+ });
+ });
+ }
+
+ for (auto& socket : new_sockets) {
+ sockets_.push_back(socket);
+ }
+ }
+
+ /*
+ * Stop listening on all sockets.
+ */
+ void stop() {
+ auto barrier = std::make_shared<boost::barrier>(sockets_.size() + 1);
+ for (auto socket : sockets_) {
+ socket->getEventBase()->runInEventBaseThread([barrier, socket]() {
+ socket->stopAccepting();
+ socket->detachEventBase();
+ barrier->wait();
+ });
+ }
+ barrier->wait();
+ sockets_.clear();
+
+ acceptor_group_->join();
+ io_group_->join();
+ }
+
+ /*
+ * Get the list of listening sockets
+ */
+ std::vector<std::shared_ptr<folly::AsyncServerSocket>>&
+ getSockets() {
+ return sockets_;
+ }
+
+ private:
+ std::shared_ptr<wangle::IOThreadPoolExecutor> acceptor_group_;
+ std::shared_ptr<wangle::IOThreadPoolExecutor> io_group_;
+
+ std::shared_ptr<ServerWorkerFactory> workerFactory_;
+ std::vector<std::shared_ptr<folly::AsyncServerSocket>> sockets_;
+
+ std::shared_ptr<AcceptorFactory> acceptorFactory_;
+ std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_;
+};
+
+} // namespace
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/channel/ChannelHandler.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/EventBaseManager.h>
+#include <folly/io/IOBuf.h>
+#include <folly/io/IOBufQueue.h>
+
+namespace folly { namespace wangle {
+
+class AsyncSocketHandler
+ : public folly::wangle::BytesToBytesHandler,
+ public AsyncSocket::ReadCallback {
+ public:
+ explicit AsyncSocketHandler(
+ std::shared_ptr<AsyncSocket> socket)
+ : socket_(std::move(socket)) {}
+
+ AsyncSocketHandler(AsyncSocketHandler&&) = default;
+
+ ~AsyncSocketHandler() {
+ if (socket_) {
+ detachReadCallback();
+ }
+ }
+
+ void attachReadCallback() {
+ socket_->setReadCB(socket_->good() ? this : nullptr);
+ }
+
+ void detachReadCallback() {
+ if (socket_->getReadCallback() == this) {
+ socket_->setReadCB(nullptr);
+ }
+ }
+
+ void attachEventBase(folly::EventBase* eventBase) {
+ if (eventBase && !socket_->getEventBase()) {
+ socket_->attachEventBase(eventBase);
+ }
+ }
+
+ void detachEventBase() {
+ detachReadCallback();
+ if (socket_->getEventBase()) {
+ socket_->detachEventBase();
+ }
+ }
+
+ void attachPipeline(Context* ctx) override {
+ CHECK(!ctx_);
+ ctx_ = ctx;
+ }
+
+ folly::wangle::Future<void> write(
+ Context* ctx,
+ std::unique_ptr<folly::IOBuf> buf) override {
+ if (UNLIKELY(!buf)) {
+ return folly::wangle::makeFuture();
+ }
+
+ if (!socket_->good()) {
+ VLOG(5) << "socket is closed in write()";
+ return folly::wangle::makeFuture<void>(AsyncSocketException(
+ AsyncSocketException::AsyncSocketExceptionType::NOT_OPEN,
+ "socket is closed in write()"));
+ }
+
+ auto cb = new WriteCallback();
+ auto future = cb->promise_.getFuture();
+ socket_->writeChain(cb, std::move(buf), ctx->getWriteFlags());
+ return future;
+ };
+
+ folly::wangle::Future<void> close(Context* ctx) {
+ if (socket_) {
+ detachReadCallback();
+ socket_->closeNow();
+ }
+ return folly::wangle::makeFuture();
+ }
+
+ // Must override to avoid warnings about hidden overloaded virtual due to
+ // AsyncSocket::ReadCallback::readEOF()
+ void readEOF(Context* ctx) override {
+ ctx->fireReadEOF();
+ }
+
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ const auto readBufferSettings = ctx_->getReadBufferSettings();
+ const auto ret = bufQueue_.preallocate(
+ readBufferSettings.first,
+ readBufferSettings.second);
+ *bufReturn = ret.first;
+ *lenReturn = ret.second;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ bufQueue_.postallocate(len);
+ ctx_->fireRead(bufQueue_);
+ }
+
+ void readEOF() noexcept override {
+ ctx_->fireReadEOF();
+ }
+
+ void readErr(const AsyncSocketException& ex)
+ noexcept override {
+ ctx_->fireReadException(make_exception_wrapper<AsyncSocketException>(ex));
+ }
+
+ private:
+ class WriteCallback : private AsyncSocket::WriteCallback {
+ void writeSuccess() noexcept override {
+ promise_.setValue();
+ delete this;
+ }
+
+ void writeErr(size_t bytesWritten,
+ const AsyncSocketException& ex)
+ noexcept override {
+ promise_.setException(ex);
+ delete this;
+ }
+
+ private:
+ friend class AsyncSocketHandler;
+ folly::wangle::Promise<void> promise_;
+ };
+
+ Context* ctx_{nullptr};
+ folly::IOBufQueue bufQueue_;
+ std::shared_ptr<AsyncSocket> socket_{nullptr};
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/futures/Future.h>
+#include <folly/wangle/channel/ChannelPipeline.h>
+#include <folly/io/IOBuf.h>
+#include <folly/io/IOBufQueue.h>
+
+namespace folly { namespace wangle {
+
+template <class Rin, class Rout = Rin, class Win = Rout, class Wout = Rin>
+class ChannelHandler {
+ public:
+ typedef Rin rin;
+ typedef Rout rout;
+ typedef Win win;
+ typedef Wout wout;
+ typedef ChannelHandlerContext<Rout, Wout> Context;
+ virtual ~ChannelHandler() {}
+
+ virtual void read(Context* ctx, Rin msg) = 0;
+ virtual void readEOF(Context* ctx) {
+ ctx->fireReadEOF();
+ }
+ virtual void readException(Context* ctx, exception_wrapper e) {
+ ctx->fireReadException(std::move(e));
+ }
+
+ virtual Future<void> write(Context* ctx, Win msg) = 0;
+ virtual Future<void> close(Context* ctx) {
+ return ctx->fireClose();
+ }
+
+ virtual void attachPipeline(Context* ctx) {}
+ virtual void attachTransport(Context* ctx) {}
+
+ virtual void detachPipeline(Context* ctx) {}
+ virtual void detachTransport(Context* ctx) {}
+
+ /*
+ // Other sorts of things we might want, all shamelessly stolen from Netty
+ // inbound
+ virtual void exceptionCaught(
+ ChannelHandlerContext* ctx,
+ exception_wrapper e) {}
+ virtual void channelRegistered(ChannelHandlerContext* ctx) {}
+ virtual void channelUnregistered(ChannelHandlerContext* ctx) {}
+ virtual void channelActive(ChannelHandlerContext* ctx) {}
+ virtual void channelInactive(ChannelHandlerContext* ctx) {}
+ virtual void channelReadComplete(ChannelHandlerContext* ctx) {}
+ virtual void userEventTriggered(ChannelHandlerContext* ctx, void* evt) {}
+ virtual void channelWritabilityChanged(ChannelHandlerContext* ctx) {}
+
+ // outbound
+ virtual Future<void> bind(
+ ChannelHandlerContext* ctx,
+ SocketAddress localAddress) {}
+ virtual Future<void> connect(
+ ChannelHandlerContext* ctx,
+ SocketAddress remoteAddress, SocketAddress localAddress) {}
+ virtual Future<void> disconnect(ChannelHandlerContext* ctx) {}
+ virtual Future<void> deregister(ChannelHandlerContext* ctx) {}
+ virtual Future<void> read(ChannelHandlerContext* ctx) {}
+ virtual void flush(ChannelHandlerContext* ctx) {}
+ */
+};
+
+template <class R, class W = R>
+class ChannelHandlerAdapter : public ChannelHandler<R, R, W, W> {
+ public:
+ typedef typename ChannelHandler<R, R, W, W>::Context Context;
+
+ void read(Context* ctx, R msg) override {
+ ctx->fireRead(std::forward<R>(msg));
+ }
+
+ Future<void> write(Context* ctx, W msg) override {
+ return ctx->fireWrite(std::forward<W>(msg));
+ }
+};
+
+typedef ChannelHandlerAdapter<IOBufQueue&, std::unique_ptr<IOBuf>>
+BytesToBytesHandler;
+
+template <class Handler, bool Shared = true>
+class ChannelHandlerPtr : public ChannelHandler<
+ typename Handler::rin,
+ typename Handler::rout,
+ typename Handler::win,
+ typename Handler::wout> {
+ public:
+ typedef typename std::conditional<
+ Shared,
+ std::shared_ptr<Handler>,
+ Handler*>::type
+ HandlerPtr;
+
+ typedef typename Handler::Context Context;
+
+ explicit ChannelHandlerPtr(HandlerPtr handler)
+ : handler_(std::move(handler)) {}
+
+ void setHandler(HandlerPtr handler) {
+ if (handler == handler_) {
+ return;
+ }
+ if (handler_ && ctx_) {
+ handler_->detachPipeline(ctx_);
+ }
+ handler_ = std::move(handler);
+ if (handler_ && ctx_) {
+ handler_->attachPipeline(ctx_);
+ if (ctx_->getTransport()) {
+ handler_->attachTransport(ctx_);
+ }
+ }
+ }
+
+ void attachPipeline(Context* ctx) override {
+ ctx_ = ctx;
+ if (handler_) {
+ handler_->attachPipeline(ctx_);
+ }
+ }
+
+ void attachTransport(Context* ctx) override {
+ ctx_ = ctx;
+ if (handler_) {
+ handler_->attachTransport(ctx_);
+ }
+ }
+
+ void detachPipeline(Context* ctx) override {
+ ctx_ = ctx;
+ if (handler_) {
+ handler_->detachPipeline(ctx_);
+ }
+ }
+
+ void detachTransport(Context* ctx) override {
+ ctx_ = ctx;
+ if (handler_) {
+ handler_->detachTransport(ctx_);
+ }
+ }
+
+ void read(Context* ctx, typename Handler::rin msg) override {
+ DCHECK(handler_);
+ handler_->read(ctx, std::forward<typename Handler::rin>(msg));
+ }
+
+ void readEOF(Context* ctx) override {
+ DCHECK(handler_);
+ handler_->readEOF(ctx);
+ }
+
+ void readException(Context* ctx, exception_wrapper e) override {
+ DCHECK(handler_);
+ handler_->readException(ctx, std::move(e));
+ }
+
+ Future<void> write(Context* ctx, typename Handler::win msg) override {
+ DCHECK(handler_);
+ return handler_->write(ctx, std::forward<typename Handler::win>(msg));
+ }
+
+ Future<void> close(Context* ctx) override {
+ DCHECK(handler_);
+ return handler_->close(ctx);
+ }
+
+ private:
+ Context* ctx_;
+ HandlerPtr handler_;
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/io/async/AsyncTransport.h>
+#include <folly/wangle/futures/Future.h>
+#include <folly/ExceptionWrapper.h>
+
+namespace folly { namespace wangle {
+
+template <class In, class Out>
+class ChannelHandlerContext {
+ public:
+ virtual ~ChannelHandlerContext() {}
+
+ virtual void fireRead(In msg) = 0;
+ virtual void fireReadEOF() = 0;
+ virtual void fireReadException(exception_wrapper e) = 0;
+
+ virtual Future<void> fireWrite(Out msg) = 0;
+ virtual Future<void> fireClose() = 0;
+
+ virtual std::shared_ptr<AsyncTransport> getTransport() = 0;
+
+ virtual void setWriteFlags(WriteFlags flags) = 0;
+ virtual WriteFlags getWriteFlags() = 0;
+
+ virtual void setReadBufferSettings(
+ uint64_t minAvailable,
+ uint64_t allocationSize) = 0;
+ virtual std::pair<uint64_t, uint64_t> getReadBufferSettings() = 0;
+
+ /* TODO
+ template <class H>
+ virtual void addHandlerBefore(H&&) {}
+ template <class H>
+ virtual void addHandlerAfter(H&&) {}
+ template <class H>
+ virtual void replaceHandler(H&&) {}
+ virtual void removeHandler() {}
+ */
+};
+
+class PipelineContext {
+ public:
+ virtual ~PipelineContext() {}
+
+ virtual void attachTransport() = 0;
+ virtual void detachTransport() = 0;
+
+ void link(PipelineContext* other) {
+ setNextIn(other);
+ other->setNextOut(this);
+ }
+
+ protected:
+ virtual void setNextIn(PipelineContext* ctx) = 0;
+ virtual void setNextOut(PipelineContext* ctx) = 0;
+};
+
+template <class In>
+class InboundChannelHandlerContext {
+ public:
+ virtual ~InboundChannelHandlerContext() {}
+ virtual void read(In msg) = 0;
+ virtual void readEOF() = 0;
+ virtual void readException(exception_wrapper e) = 0;
+};
+
+template <class Out>
+class OutboundChannelHandlerContext {
+ public:
+ virtual ~OutboundChannelHandlerContext() {}
+ virtual Future<void> write(Out msg) = 0;
+ virtual Future<void> close() = 0;
+};
+
+template <class P, class H>
+class ContextImpl : public ChannelHandlerContext<typename H::rout,
+ typename H::wout>,
+ public InboundChannelHandlerContext<typename H::rin>,
+ public OutboundChannelHandlerContext<typename H::win>,
+ public PipelineContext {
+ public:
+ typedef typename H::rin Rin;
+ typedef typename H::rout Rout;
+ typedef typename H::win Win;
+ typedef typename H::wout Wout;
+
+ template <class HandlerArg>
+ explicit ContextImpl(P* pipeline, HandlerArg&& handlerArg)
+ : pipeline_(pipeline),
+ handler_(std::forward<HandlerArg>(handlerArg)) {
+ handler_.attachPipeline(this);
+ }
+
+ ~ContextImpl() {
+ handler_.detachPipeline(this);
+ }
+
+ H* getHandler() {
+ return &handler_;
+ }
+
+ // PipelineContext overrides
+ void setNextIn(PipelineContext* ctx) override {
+ auto nextIn = dynamic_cast<InboundChannelHandlerContext<Rout>*>(ctx);
+ if (nextIn) {
+ nextIn_ = nextIn;
+ } else {
+ throw std::invalid_argument("wrong type in setNextIn");
+ }
+ }
+
+ void setNextOut(PipelineContext* ctx) override {
+ auto nextOut = dynamic_cast<OutboundChannelHandlerContext<Wout>*>(ctx);
+ if (nextOut) {
+ nextOut_ = nextOut;
+ } else {
+ throw std::invalid_argument("wrong type in setNextOut");
+ }
+ }
+
+ void attachTransport() override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ handler_.attachTransport(this);
+ }
+
+ void detachTransport() override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ handler_.detachTransport(this);
+ }
+
+ // ChannelHandlerContext overrides
+ void fireRead(Rout msg) override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ if (nextIn_) {
+ nextIn_->read(std::forward<Rout>(msg));
+ } else {
+ LOG(WARNING) << "read reached end of pipeline";
+ }
+ }
+
+ void fireReadEOF() override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ if (nextIn_) {
+ nextIn_->readEOF();
+ } else {
+ LOG(WARNING) << "readEOF reached end of pipeline";
+ }
+ }
+
+ void fireReadException(exception_wrapper e) override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ if (nextIn_) {
+ nextIn_->readException(std::move(e));
+ } else {
+ LOG(WARNING) << "readException reached end of pipeline";
+ }
+ }
+
+ Future<void> fireWrite(Wout msg) override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ if (nextOut_) {
+ return nextOut_->write(std::forward<Wout>(msg));
+ } else {
+ LOG(WARNING) << "write reached end of pipeline";
+ return makeFuture();
+ }
+ }
+
+ Future<void> fireClose() override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ if (nextOut_) {
+ return nextOut_->close();
+ } else {
+ LOG(WARNING) << "close reached end of pipeline";
+ return makeFuture();
+ }
+ }
+
+ std::shared_ptr<AsyncTransport> getTransport() override {
+ return pipeline_->getTransport();
+ }
+
+ void setWriteFlags(WriteFlags flags) override {
+ pipeline_->setWriteFlags(flags);
+ }
+
+ WriteFlags getWriteFlags() override {
+ return pipeline_->getWriteFlags();
+ }
+
+ void setReadBufferSettings(
+ uint64_t minAvailable,
+ uint64_t allocationSize) override {
+ pipeline_->setReadBufferSettings(minAvailable, allocationSize);
+ }
+
+ std::pair<uint64_t, uint64_t> getReadBufferSettings() override {
+ return pipeline_->getReadBufferSettings();
+ }
+
+ // InboundChannelHandlerContext overrides
+ void read(Rin msg) override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ handler_.read(this, std::forward<Rin>(msg));
+ }
+
+ void readEOF() override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ handler_.readEOF(this);
+ }
+
+ void readException(exception_wrapper e) override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ handler_.readException(this, std::move(e));
+ }
+
+ // OutboundChannelHandlerContext overrides
+ Future<void> write(Win msg) override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ return handler_.write(this, std::forward<Win>(msg));
+ }
+
+ Future<void> close() override {
+ typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+ return handler_.close(this);
+ }
+
+ private:
+ P* pipeline_;
+ H handler_;
+ InboundChannelHandlerContext<Rout>* nextIn_{nullptr};
+ OutboundChannelHandlerContext<Wout>* nextOut_{nullptr};
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/channel/ChannelHandlerContext.h>
+#include <folly/wangle/futures/Future.h>
+#include <folly/io/async/AsyncTransport.h>
+#include <folly/io/async/DelayedDestruction.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/Memory.h>
+#include <glog/logging.h>
+
+namespace folly { namespace wangle {
+
+/*
+ * R is the inbound type, i.e. inbound calls start with pipeline.read(R)
+ * W is the outbound type, i.e. outbound calls start with pipeline.write(W)
+ */
+template <class R, class W, class... Handlers>
+class ChannelPipeline;
+
+template <class R, class W>
+class ChannelPipeline<R, W> : public DelayedDestruction {
+ public:
+ ChannelPipeline() {}
+ ~ChannelPipeline() {}
+
+ std::shared_ptr<AsyncTransport> getTransport() {
+ return transport_;
+ }
+
+ void setWriteFlags(WriteFlags flags) {
+ writeFlags_ = flags;
+ }
+
+ WriteFlags getWriteFlags() {
+ return writeFlags_;
+ }
+
+ void setReadBufferSettings(uint64_t minAvailable, uint64_t allocationSize) {
+ readBufferSettings_ = std::make_pair(minAvailable, allocationSize);
+ }
+
+ std::pair<uint64_t, uint64_t> getReadBufferSettings() {
+ return readBufferSettings_;
+ }
+
+ void read(R msg) {
+ front_->read(std::forward<R>(msg));
+ }
+
+ void readEOF() {
+ front_->readEOF();
+ }
+
+ void readException(exception_wrapper e) {
+ front_->readException(std::move(e));
+ }
+
+ Future<void> write(W msg) {
+ return back_->write(std::forward<W>(msg));
+ }
+
+ Future<void> close() {
+ return back_->close();
+ }
+
+ template <class H>
+ ChannelPipeline& addBack(H&& handler) {
+ ctxs_.push_back(folly::make_unique<ContextImpl<ChannelPipeline, H>>(
+ this, std::forward<H>(handler)));
+ return *this;
+ }
+
+ template <class H>
+ ChannelPipeline& addFront(H&& handler) {
+ ctxs_.insert(
+ ctxs_.begin(),
+ folly::make_unique<ContextImpl<ChannelPipeline, H>>(
+ this,
+ std::forward<H>(handler)));
+ return *this;
+ }
+
+ template <class H>
+ H* getHandler(int i) {
+ auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(ctxs_[i].get());
+ CHECK(ctx);
+ return ctx->getHandler();
+ }
+
+ void finalize() {
+ finalizeHelper();
+ InboundChannelHandlerContext<R>* front;
+ front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(
+ ctxs_.front().get());
+ if (!front_) {
+ throw std::invalid_argument("wrong type for first handler");
+ }
+ }
+
+ protected:
+ explicit ChannelPipeline(bool shouldFinalize) {
+ CHECK(!shouldFinalize);
+ }
+
+ void finalizeHelper() {
+ if (ctxs_.empty()) {
+ return;
+ }
+
+ for (int i = 0; i < ctxs_.size() - 1; i++) {
+ ctxs_[i]->link(ctxs_[i+1].get());
+ }
+
+ back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(ctxs_.back().get());
+ if (!back_) {
+ throw std::invalid_argument("wrong type for last handler");
+ }
+ }
+
+ PipelineContext* getLocalFront() {
+ return ctxs_.empty() ? nullptr : ctxs_.front().get();
+ }
+
+ static const bool is_end{true};
+
+ std::shared_ptr<AsyncTransport> transport_;
+ WriteFlags writeFlags_{WriteFlags::NONE};
+ std::pair<uint64_t, uint64_t> readBufferSettings_{2048, 2048};
+
+ void attachPipeline() {}
+
+ void attachTransport(
+ std::shared_ptr<AsyncTransport> transport) {
+ transport_ = std::move(transport);
+ }
+
+ void detachTransport() {
+ transport_ = nullptr;
+ }
+
+ OutboundChannelHandlerContext<W>* back_{nullptr};
+
+ private:
+ InboundChannelHandlerContext<R>* front_{nullptr};
+ std::vector<std::unique_ptr<PipelineContext>> ctxs_;
+};
+
+template <class R, class W, class Handler, class... Handlers>
+class ChannelPipeline<R, W, Handler, Handlers...>
+ : public ChannelPipeline<R, W, Handlers...> {
+ protected:
+ template <class HandlerArg, class... HandlersArgs>
+ ChannelPipeline(
+ bool shouldFinalize,
+ HandlerArg&& handlerArg,
+ HandlersArgs&&... handlersArgs)
+ : ChannelPipeline<R, W, Handlers...>(
+ false,
+ std::forward<HandlersArgs>(handlersArgs)...),
+ ctx_(this, std::forward<HandlerArg>(handlerArg)) {
+ if (shouldFinalize) {
+ finalize();
+ }
+ }
+
+ public:
+ template <class... HandlersArgs>
+ explicit ChannelPipeline(HandlersArgs&&... handlersArgs)
+ : ChannelPipeline(true, std::forward<HandlersArgs>(handlersArgs)...) {}
+
+ ~ChannelPipeline() {}
+
+ void destroy() override { }
+
+ void read(R msg) {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ front_->read(std::forward<R>(msg));
+ }
+
+ void readEOF() {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ front_->readEOF();
+ }
+
+ void readException(exception_wrapper e) {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ front_->readException(std::move(e));
+ }
+
+ Future<void> write(W msg) {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ return back_->write(std::forward<W>(msg));
+ }
+
+ Future<void> close() {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ return back_->close();
+ }
+
+ void attachTransport(
+ std::shared_ptr<AsyncTransport> transport) {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ CHECK((!ChannelPipeline<R, W>::transport_));
+ ChannelPipeline<R, W, Handlers...>::attachTransport(std::move(transport));
+ forEachCtx([&](PipelineContext* ctx){
+ ctx->attachTransport();
+ });
+ }
+
+ void detachTransport() {
+ typename ChannelPipeline<R, W>::DestructorGuard dg(
+ static_cast<DelayedDestruction*>(this));
+ ChannelPipeline<R, W, Handlers...>::detachTransport();
+ forEachCtx([&](PipelineContext* ctx){
+ ctx->detachTransport();
+ });
+ }
+
+ std::shared_ptr<AsyncTransport> getTransport() {
+ return ChannelPipeline<R, W>::transport_;
+ }
+
+ template <class H>
+ ChannelPipeline& addBack(H&& handler) {
+ ChannelPipeline<R, W>::addBack(std::move(handler));
+ return *this;
+ }
+
+ template <class H>
+ ChannelPipeline& addFront(H&& handler) {
+ ctxs_.insert(
+ ctxs_.begin(),
+ folly::make_unique<ContextImpl<ChannelPipeline, H>>(
+ this,
+ std::move(handler)));
+ return *this;
+ }
+
+ template <class H>
+ H* getHandler(size_t i) {
+ if (i > ctxs_.size()) {
+ return ChannelPipeline<R, W, Handlers...>::template getHandler<H>(
+ i - (ctxs_.size() + 1));
+ } else {
+ auto pctx = (i == ctxs_.size()) ? &ctx_ : ctxs_[i].get();
+ auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(pctx);
+ return ctx->getHandler();
+ }
+ }
+
+ void finalize() {
+ finalizeHelper();
+ auto ctx = ctxs_.empty() ? &ctx_ : ctxs_.front().get();
+ front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(ctx);
+ if (!front_) {
+ throw std::invalid_argument("wrong type for first handler");
+ }
+ }
+
+ protected:
+ void finalizeHelper() {
+ ChannelPipeline<R, W, Handlers...>::finalizeHelper();
+ back_ = ChannelPipeline<R, W, Handlers...>::back_;
+ if (!back_) {
+ auto is_end = ChannelPipeline<R, W, Handlers...>::is_end;
+ CHECK(is_end);
+ back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(&ctx_);
+ if (!back_) {
+ throw std::invalid_argument("wrong type for last handler");
+ }
+ }
+
+ if (!ctxs_.empty()) {
+ for (int i = 0; i < ctxs_.size() - 1; i++) {
+ ctxs_[i]->link(ctxs_[i+1].get());
+ }
+ ctxs_.back()->link(&ctx_);
+ }
+
+ auto nextFront = ChannelPipeline<R, W, Handlers...>::getLocalFront();
+ if (nextFront) {
+ ctx_.link(nextFront);
+ }
+ }
+
+ PipelineContext* getLocalFront() {
+ return ctxs_.empty() ? &ctx_ : ctxs_.front().get();
+ }
+
+ static const bool is_end{false};
+ InboundChannelHandlerContext<R>* front_{nullptr};
+ OutboundChannelHandlerContext<W>* back_{nullptr};
+
+ private:
+ template <class F>
+ void forEachCtx(const F& func) {
+ for (auto& ctx : ctxs_) {
+ func(ctx.get());
+ }
+ func(&ctx_);
+ }
+
+ ContextImpl<ChannelPipeline, Handler> ctx_;
+ std::vector<std::unique_ptr<PipelineContext>> ctxs_;
+};
+
+}}
+
+namespace folly {
+
+class AsyncSocket;
+
+template <typename Pipeline>
+class PipelineFactory {
+ public:
+ virtual Pipeline* newPipeline(std::shared_ptr<AsyncSocket>) = 0;
+ virtual ~PipelineFactory() {}
+};
+
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/channel/ChannelHandler.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/EventBaseManager.h>
+#include <folly/io/IOBuf.h>
+#include <folly/io/IOBufQueue.h>
+
+namespace folly { namespace wangle {
+
+/*
+ * OutputBufferingHandler buffers writes in order to minimize syscalls. The
+ * transport will be written to once per event loop instead of on every write.
+ */
+class OutputBufferingHandler : public BytesToBytesHandler,
+ protected EventBase::LoopCallback {
+ public:
+ Future<void> write(Context* ctx, std::unique_ptr<IOBuf> buf) override {
+ CHECK(buf);
+ if (!queueSends_) {
+ return ctx->fireWrite(std::move(buf));
+ } else {
+ ctx_ = ctx;
+ // Delay sends to optimize for fewer syscalls
+ if (!sends_) {
+ DCHECK(!isLoopCallbackScheduled());
+ // Buffer all the sends, and call writev once per event loop.
+ sends_ = std::move(buf);
+ ctx->getTransport()->getEventBase()->runInLoop(this);
+ } else {
+ DCHECK(isLoopCallbackScheduled());
+ sends_->prependChain(std::move(buf));
+ }
+ Promise<void> p;
+ auto f = p.getFuture();
+ promises_.push_back(std::move(p));
+ return f;
+ }
+ }
+
+ void runLoopCallback() noexcept override {
+ MoveWrapper<std::vector<Promise<void>>> promises(std::move(promises_));
+ ctx_->fireWrite(std::move(sends_)).then([promises](Try<void>&& t) mutable {
+ try {
+ t.throwIfFailed();
+ for (auto& p : *promises) {
+ p.setValue();
+ }
+ } catch (...) {
+ for (auto& p : *promises) {
+ p.setException(std::current_exception());
+ }
+ }
+ });
+ }
+
+ std::vector<Promise<void>> promises_;
+ std::unique_ptr<IOBuf> sends_{nullptr};
+ bool queueSends_{true};
+ Context* ctx_;
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/channel/ChannelHandler.h>
+#include <folly/wangle/channel/ChannelPipeline.h>
+#include <folly/wangle/channel/AsyncSocketHandler.h>
+#include <folly/wangle/channel/OutputBufferingHandler.h>
+#include <folly/wangle/channel/test/MockChannelHandler.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace folly;
+using namespace folly::wangle;
+using namespace testing;
+
+typedef StrictMock<MockChannelHandlerAdapter<int, int>> IntHandler;
+typedef ChannelHandlerPtr<IntHandler, false> IntHandlerPtr;
+
+ACTION(FireRead) {
+ arg0->fireRead(arg1);
+}
+
+ACTION(FireReadEOF) {
+ arg0->fireReadEOF();
+}
+
+ACTION(FireReadException) {
+ arg0->fireReadException(arg1);
+}
+
+ACTION(FireWrite) {
+ arg0->fireWrite(arg1);
+}
+
+ACTION(FireClose) {
+ arg0->fireClose();
+}
+
+// Test move only types, among other things
+TEST(ChannelTest, RealHandlersCompile) {
+ EventBase eb;
+ auto socket = AsyncSocket::newSocket(&eb);
+ // static
+ {
+ ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>,
+ AsyncSocketHandler,
+ OutputBufferingHandler>
+ pipeline{AsyncSocketHandler(socket), OutputBufferingHandler()};
+ EXPECT_TRUE(pipeline.getHandler<AsyncSocketHandler>(0));
+ EXPECT_TRUE(pipeline.getHandler<OutputBufferingHandler>(1));
+ }
+ // dynamic
+ {
+ ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>> pipeline;
+ pipeline
+ .addBack(AsyncSocketHandler(socket))
+ .addBack(OutputBufferingHandler())
+ .finalize();
+ EXPECT_TRUE(pipeline.getHandler<AsyncSocketHandler>(0));
+ EXPECT_TRUE(pipeline.getHandler<OutputBufferingHandler>(1));
+ }
+}
+
+// Test that handlers correctly fire the next handler when directed
+TEST(ChannelTest, FireActions) {
+ IntHandler handler1;
+ IntHandler handler2;
+
+ EXPECT_CALL(handler1, attachPipeline(_));
+ EXPECT_CALL(handler2, attachPipeline(_));
+
+ ChannelPipeline<int, int, IntHandlerPtr, IntHandlerPtr>
+ pipeline(&handler1, &handler2);
+
+ EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead());
+ EXPECT_CALL(handler2, read_(_, _)).Times(1);
+ pipeline.read(1);
+
+ EXPECT_CALL(handler1, readEOF(_)).WillOnce(FireReadEOF());
+ EXPECT_CALL(handler2, readEOF(_)).Times(1);
+ pipeline.readEOF();
+
+ EXPECT_CALL(handler1, readException(_, _)).WillOnce(FireReadException());
+ EXPECT_CALL(handler2, readException(_, _)).Times(1);
+ pipeline.readException(make_exception_wrapper<std::runtime_error>("blah"));
+
+ EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite());
+ EXPECT_CALL(handler1, write_(_, _)).Times(1);
+ EXPECT_NO_THROW(pipeline.write(1).value());
+
+ EXPECT_CALL(handler2, close_(_)).WillOnce(FireClose());
+ EXPECT_CALL(handler1, close_(_)).Times(1);
+ EXPECT_NO_THROW(pipeline.close().value());
+
+ EXPECT_CALL(handler1, detachPipeline(_));
+ EXPECT_CALL(handler2, detachPipeline(_));
+}
+
+// Test that nothing bad happens when actions reach the end of the pipeline
+// (a warning will be logged, however)
+TEST(ChannelTest, ReachEndOfPipeline) {
+ IntHandler handler;
+ EXPECT_CALL(handler, attachPipeline(_));
+ ChannelPipeline<int, int, IntHandlerPtr>
+ pipeline(&handler);
+
+ EXPECT_CALL(handler, read_(_, _)).WillOnce(FireRead());
+ pipeline.read(1);
+
+ EXPECT_CALL(handler, readEOF(_)).WillOnce(FireReadEOF());
+ pipeline.readEOF();
+
+ EXPECT_CALL(handler, readException(_, _)).WillOnce(FireReadException());
+ pipeline.readException(make_exception_wrapper<std::runtime_error>("blah"));
+
+ EXPECT_CALL(handler, write_(_, _)).WillOnce(FireWrite());
+ EXPECT_NO_THROW(pipeline.write(1).value());
+
+ EXPECT_CALL(handler, close_(_)).WillOnce(FireClose());
+ EXPECT_NO_THROW(pipeline.close().value());
+
+ EXPECT_CALL(handler, detachPipeline(_));
+}
+
+// Test having the last read handler turn around and write
+TEST(ChannelTest, TurnAround) {
+ IntHandler handler1;
+ IntHandler handler2;
+
+ EXPECT_CALL(handler1, attachPipeline(_));
+ EXPECT_CALL(handler2, attachPipeline(_));
+
+ ChannelPipeline<int, int, IntHandlerPtr, IntHandlerPtr>
+ pipeline(&handler1, &handler2);
+
+ EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead());
+ EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireWrite());
+ EXPECT_CALL(handler1, write_(_, _)).Times(1);
+ pipeline.read(1);
+
+ EXPECT_CALL(handler1, detachPipeline(_));
+ EXPECT_CALL(handler2, detachPipeline(_));
+}
+
+TEST(ChannelTest, DynamicFireActions) {
+ IntHandler handler1, handler2, handler3;
+ EXPECT_CALL(handler2, attachPipeline(_));
+ ChannelPipeline<int, int, IntHandlerPtr>
+ pipeline(&handler2);
+
+ EXPECT_CALL(handler1, attachPipeline(_));
+ EXPECT_CALL(handler3, attachPipeline(_));
+
+ pipeline
+ .addFront(IntHandlerPtr(&handler1))
+ .addBack(IntHandlerPtr(&handler3))
+ .finalize();
+
+ EXPECT_TRUE(pipeline.getHandler<IntHandlerPtr>(0));
+ EXPECT_TRUE(pipeline.getHandler<IntHandlerPtr>(1));
+ EXPECT_TRUE(pipeline.getHandler<IntHandlerPtr>(2));
+
+ EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead());
+ EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireRead());
+ EXPECT_CALL(handler3, read_(_, _)).Times(1);
+ pipeline.read(1);
+
+ EXPECT_CALL(handler3, write_(_, _)).WillOnce(FireWrite());
+ EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite());
+ EXPECT_CALL(handler1, write_(_, _)).Times(1);
+ EXPECT_NO_THROW(pipeline.write(1).value());
+
+ EXPECT_CALL(handler1, detachPipeline(_));
+ EXPECT_CALL(handler2, detachPipeline(_));
+ EXPECT_CALL(handler3, detachPipeline(_));
+}
+
+template <class Rin, class Rout = Rin, class Win = Rout, class Wout = Rin>
+class ConcreteChannelHandler : public ChannelHandler<Rin, Rout, Win, Wout> {
+ typedef typename ChannelHandler<Rin, Rout, Win, Wout>::Context Context;
+ public:
+ void read(Context* ctx, Rin msg) {}
+ Future<void> write(Context* ctx, Win msg) { return makeFuture(); }
+};
+
+typedef ChannelHandlerAdapter<std::string, std::string> StringHandler;
+typedef ConcreteChannelHandler<int, std::string> IntToStringHandler;
+typedef ConcreteChannelHandler<std::string, int> StringToIntHandler;
+
+TEST(ChannelPipeline, DynamicConstruction) {
+ {
+ ChannelPipeline<int, int> pipeline;
+ EXPECT_THROW(
+ pipeline
+ .addBack(ChannelHandlerAdapter<std::string, std::string>{})
+ .finalize(), std::invalid_argument);
+ }
+ {
+ ChannelPipeline<int, int> pipeline;
+ EXPECT_THROW(
+ pipeline
+ .addFront(ChannelHandlerAdapter<std::string, std::string>{})
+ .finalize(),
+ std::invalid_argument);
+ }
+ {
+ ChannelPipeline<std::string, std::string, StringHandler, StringHandler>
+ pipeline{StringHandler(), StringHandler()};
+
+ // Exercise both addFront and addBack. Final pipeline is
+ // StI <-> ItS <-> StS <-> StS <-> StI <-> ItS
+ EXPECT_NO_THROW(
+ pipeline
+ .addFront(IntToStringHandler{})
+ .addFront(StringToIntHandler{})
+ .addBack(StringToIntHandler{})
+ .addBack(IntToStringHandler{})
+ .finalize());
+ }
+}
+
+TEST(ChannelPipeline, AttachTransport) {
+ IntHandler handler;
+ EXPECT_CALL(handler, attachPipeline(_));
+ ChannelPipeline<int, int, IntHandlerPtr>
+ pipeline(&handler);
+
+ EventBase eb;
+ auto socket = AsyncSocket::newSocket(&eb);
+
+ EXPECT_CALL(handler, attachTransport(_));
+ pipeline.attachTransport(socket);
+
+ EXPECT_CALL(handler, detachTransport(_));
+ pipeline.detachTransport();
+
+ EXPECT_CALL(handler, detachPipeline(_));
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/channel/ChannelHandler.h>
+#include <gmock/gmock.h>
+
+namespace folly { namespace wangle {
+
+template <class Rin, class Rout = Rin, class Win = Rout, class Wout = Rin>
+class MockChannelHandler : public ChannelHandler<Rin, Rout, Win, Wout> {
+ public:
+ typedef typename ChannelHandler<Rin, Rout, Win, Wout>::Context Context;
+
+ MockChannelHandler() = default;
+ MockChannelHandler(MockChannelHandler&&) = default;
+
+ MOCK_METHOD2_T(read_, void(Context*, Rin&));
+ MOCK_METHOD1_T(readEOF, void(Context*));
+ MOCK_METHOD2_T(readException, void(Context*, exception_wrapper));
+
+ MOCK_METHOD2_T(write_, void(Context*, Win&));
+ MOCK_METHOD1_T(close_, void(Context*));
+
+ MOCK_METHOD1_T(attachPipeline, void(Context*));
+ MOCK_METHOD1_T(attachTransport, void(Context*));
+ MOCK_METHOD1_T(detachPipeline, void(Context*));
+ MOCK_METHOD1_T(detachTransport, void(Context*));
+
+ void read(Context* ctx, Rin msg) {
+ read_(ctx, msg);
+ }
+
+ Future<void> write(Context* ctx, Win msg) override {
+ return makeFutureTry([&](){
+ write_(ctx, msg);
+ });
+ }
+
+ Future<void> close(Context* ctx) override {
+ return makeFutureTry([&](){
+ close_(ctx);
+ });
+ }
+};
+
+template <class R, class W = R>
+using MockChannelHandlerAdapter = MockChannelHandler<R, R, W, W>;
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/channel/ChannelPipeline.h>
+#include <folly/wangle/channel/OutputBufferingHandler.h>
+#include <folly/wangle/channel/test/MockChannelHandler.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace folly;
+using namespace folly::wangle;
+using namespace testing;
+
+typedef StrictMock<MockChannelHandlerAdapter<
+ IOBufQueue&,
+ std::unique_ptr<IOBuf>>>
+MockHandler;
+
+MATCHER_P(IOBufContains, str, "") { return arg->moveToFbString() == str; }
+
+TEST(OutputBufferingHandlerTest, Basic) {
+ MockHandler mockHandler;
+ EXPECT_CALL(mockHandler, attachPipeline(_));
+ ChannelPipeline<IOBufQueue&, std::unique_ptr<IOBuf>,
+ ChannelHandlerPtr<MockHandler, false>,
+ OutputBufferingHandler>
+ pipeline(&mockHandler, OutputBufferingHandler{});
+
+ EventBase eb;
+ auto socket = AsyncSocket::newSocket(&eb);
+ EXPECT_CALL(mockHandler, attachTransport(_));
+ pipeline.attachTransport(socket);
+
+ // Buffering should prevent writes until the EB loops, and the writes should
+ // be batched into one write call.
+ auto f1 = pipeline.write(IOBuf::copyBuffer("hello"));
+ auto f2 = pipeline.write(IOBuf::copyBuffer("world"));
+ EXPECT_FALSE(f1.isReady());
+ EXPECT_FALSE(f2.isReady());
+ EXPECT_CALL(mockHandler, write_(_, IOBufContains("helloworld")));
+ eb.loopOnce();
+ EXPECT_TRUE(f1.isReady());
+ EXPECT_TRUE(f2.isReady());
+ EXPECT_CALL(mockHandler, detachPipeline(_));
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <glog/logging.h>
+
+namespace folly { namespace wangle {
+
+template <class T>
+class BlockingQueue {
+ public:
+ virtual ~BlockingQueue() {}
+ virtual void add(T item) = 0;
+ virtual void addWithPriority(T item, uint32_t priority) {
+ LOG_FIRST_N(WARNING, 1) <<
+ "add(item, priority) called on a non-priority queue";
+ add(std::move(item));
+ }
+ virtual uint32_t getNumPriorities() {
+ LOG_FIRST_N(WARNING, 1) <<
+ "getNumPriorities() called on a non-priority queue";
+ return 1;
+ }
+ virtual T take() = 0;
+ virtual size_t size() = 0;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/concurrent/CPUThreadPoolExecutor.h>
+#include <folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h>
+
+namespace folly { namespace wangle {
+
+const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 18;
+const size_t CPUThreadPoolExecutor::kDefaultNumPriorities = 2;
+
+CPUThreadPoolExecutor::CPUThreadPoolExecutor(
+ size_t numThreads,
+ std::unique_ptr<BlockingQueue<CPUTask>> taskQueue,
+ std::shared_ptr<ThreadFactory> threadFactory)
+ : ThreadPoolExecutor(numThreads, std::move(threadFactory)),
+ taskQueue_(std::move(taskQueue)) {
+ addThreads(numThreads);
+ CHECK(threadList_.get().size() == numThreads);
+}
+
+CPUThreadPoolExecutor::CPUThreadPoolExecutor(
+ size_t numThreads,
+ std::shared_ptr<ThreadFactory> threadFactory)
+ : CPUThreadPoolExecutor(
+ numThreads,
+ folly::make_unique<LifoSemMPMCQueue<CPUTask>>(
+ CPUThreadPoolExecutor::kDefaultMaxQueueSize),
+ std::move(threadFactory)) {}
+
+CPUThreadPoolExecutor::CPUThreadPoolExecutor(size_t numThreads)
+ : CPUThreadPoolExecutor(
+ numThreads,
+ std::make_shared<NamedThreadFactory>("CPUThreadPool")) {}
+
+CPUThreadPoolExecutor::CPUThreadPoolExecutor(
+ size_t numThreads,
+ uint32_t numPriorities,
+ std::shared_ptr<ThreadFactory> threadFactory)
+ : CPUThreadPoolExecutor(
+ numThreads,
+ folly::make_unique<PriorityLifoSemMPMCQueue<CPUTask>>(
+ numPriorities,
+ CPUThreadPoolExecutor::kDefaultMaxQueueSize),
+ std::move(threadFactory)) {}
+
+CPUThreadPoolExecutor::~CPUThreadPoolExecutor() {
+ stop();
+ CHECK(threadsToStop_ == 0);
+}
+
+void CPUThreadPoolExecutor::add(Func func) {
+ add(std::move(func), std::chrono::milliseconds(0));
+}
+
+void CPUThreadPoolExecutor::add(
+ Func func,
+ std::chrono::milliseconds expiration,
+ Func expireCallback) {
+ // TODO handle enqueue failure, here and in other add() callsites
+ taskQueue_->add(
+ CPUTask(std::move(func), expiration, std::move(expireCallback)));
+}
+
+void CPUThreadPoolExecutor::add(Func func, uint32_t priority) {
+ add(std::move(func), priority, std::chrono::milliseconds(0));
+}
+
+void CPUThreadPoolExecutor::add(
+ Func func,
+ uint32_t priority,
+ std::chrono::milliseconds expiration,
+ Func expireCallback) {
+ CHECK(priority < getNumPriorities());
+ taskQueue_->addWithPriority(
+ CPUTask(std::move(func), expiration, std::move(expireCallback)),
+ priority);
+}
+
+uint32_t CPUThreadPoolExecutor::getNumPriorities() const {
+ return taskQueue_->getNumPriorities();
+}
+
+BlockingQueue<CPUThreadPoolExecutor::CPUTask>*
+CPUThreadPoolExecutor::getTaskQueue() {
+ return taskQueue_.get();
+}
+
+void CPUThreadPoolExecutor::threadRun(std::shared_ptr<Thread> thread) {
+ thread->startupBaton.post();
+ while (1) {
+ auto task = taskQueue_->take();
+ if (UNLIKELY(task.poison)) {
+ CHECK(threadsToStop_-- > 0);
+ stoppedThreads_.add(thread);
+ return;
+ } else {
+ runTask(thread, std::move(task));
+ }
+
+ if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) {
+ if (--threadsToStop_ >= 0) {
+ stoppedThreads_.add(thread);
+ return;
+ } else {
+ threadsToStop_++;
+ }
+ }
+ }
+}
+
+void CPUThreadPoolExecutor::stopThreads(size_t n) {
+ CHECK(stoppedThreads_.size() == 0);
+ threadsToStop_ = n;
+ for (size_t i = 0; i < n; i++) {
+ taskQueue_->add(CPUTask());
+ }
+}
+
+uint64_t CPUThreadPoolExecutor::getPendingTaskCount() {
+ return taskQueue_->size();
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/concurrent/ThreadPoolExecutor.h>
+
+namespace folly { namespace wangle {
+
+class CPUThreadPoolExecutor : public ThreadPoolExecutor {
+ public:
+ struct CPUTask;
+
+ explicit CPUThreadPoolExecutor(
+ size_t numThreads,
+ std::unique_ptr<BlockingQueue<CPUTask>> taskQueue,
+ std::shared_ptr<ThreadFactory> threadFactory =
+ std::make_shared<NamedThreadFactory>("CPUThreadPool"));
+
+ explicit CPUThreadPoolExecutor(size_t numThreads);
+
+ explicit CPUThreadPoolExecutor(
+ size_t numThreads,
+ std::shared_ptr<ThreadFactory> threadFactory);
+
+ explicit CPUThreadPoolExecutor(
+ size_t numThreads,
+ uint32_t numPriorities,
+ std::shared_ptr<ThreadFactory> threadFactory =
+ std::make_shared<NamedThreadFactory>("CPUThreadPool"));
+
+ ~CPUThreadPoolExecutor();
+
+ void add(Func func) override;
+ void add(
+ Func func,
+ std::chrono::milliseconds expiration,
+ Func expireCallback = nullptr) override;
+
+ void add(Func func, uint32_t priority);
+ void add(
+ Func func,
+ uint32_t priority,
+ std::chrono::milliseconds expiration,
+ Func expireCallback = nullptr);
+
+ uint32_t getNumPriorities() const;
+
+ struct CPUTask : public ThreadPoolExecutor::Task {
+ // Must be noexcept move constructible so it can be used in MPMCQueue
+ explicit CPUTask(
+ Func&& f,
+ std::chrono::milliseconds expiration,
+ Func&& expireCallback)
+ : Task(std::move(f), expiration, std::move(expireCallback)),
+ poison(false) {}
+ CPUTask()
+ : Task(nullptr, std::chrono::milliseconds(0), nullptr),
+ poison(true) {}
+ CPUTask(CPUTask&& o) noexcept : Task(std::move(o)), poison(o.poison) {}
+ CPUTask(const CPUTask&) = default;
+ CPUTask& operator=(const CPUTask&) = default;
+ bool poison;
+ };
+
+ static const size_t kDefaultMaxQueueSize;
+ static const size_t kDefaultNumPriorities;
+
+ protected:
+ BlockingQueue<CPUTask>* getTaskQueue();
+
+ private:
+ void threadRun(ThreadPtr thread) override;
+ void stopThreads(size_t n) override;
+ uint64_t getPendingTaskCount() override;
+
+ std::unique_ptr<BlockingQueue<CPUTask>> taskQueue_;
+ std::atomic<ssize_t> threadsToStop_{0};
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/concurrent/Codel.h>
+#include <algorithm>
+#include <math.h>
+
+#ifndef NO_LIB_GFLAGS
+ #include <gflags/gflags.h>
+ DEFINE_int32(codel_interval, 100,
+ "Codel default interval time in ms");
+ DEFINE_int32(codel_target_delay, 5,
+ "Target codel queueing delay in ms");
+#endif
+
+namespace folly { namespace wangle {
+
+#ifdef NO_LIB_GFLAGS
+ int32_t FLAGS_codel_interval = 100;
+ int32_t FLAGS_codel_target_delay = 5;
+#endif
+
+Codel::Codel()
+ : codelMinDelay_(0),
+ codelIntervalTime_(std::chrono::steady_clock::now()),
+ codelResetDelay_(true),
+ overloaded_(false) {}
+
+bool Codel::overloaded(std::chrono::microseconds delay) {
+ bool ret = false;
+ auto now = std::chrono::steady_clock::now();
+
+ // Avoid another thread updating the value at the same time we are using it
+ // to calculate the overloaded state
+ auto minDelay = codelMinDelay_;
+
+ if (now > codelIntervalTime_ &&
+ (!codelResetDelay_.load(std::memory_order_acquire)
+ && !codelResetDelay_.exchange(true))) {
+ codelIntervalTime_ = now + std::chrono::milliseconds(FLAGS_codel_interval);
+
+ if (minDelay > std::chrono::milliseconds(FLAGS_codel_target_delay)) {
+ overloaded_ = true;
+ } else {
+ overloaded_ = false;
+ }
+ }
+ // Care must be taken that only a single thread resets codelMinDelay_,
+ // and that it happens after the interval reset above
+ if (codelResetDelay_.load(std::memory_order_acquire) &&
+ codelResetDelay_.exchange(false)) {
+ codelMinDelay_ = delay;
+ // More than one request must come in during an interval before codel
+ // starts dropping requests
+ return false;
+ } else if(delay < codelMinDelay_) {
+ codelMinDelay_ = delay;
+ }
+
+ if (overloaded_ &&
+ delay > std::chrono::milliseconds(FLAGS_codel_target_delay * 2)) {
+ ret = true;
+ }
+
+ return ret;
+
+}
+
+int Codel::getLoad() {
+ return std::min(100, (int)codelMinDelay_.count() /
+ (2 * FLAGS_codel_target_delay));
+}
+
+int Codel::getMinDelay() {
+ return (int) codelMinDelay_.count();
+}
+
+}} //namespace
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <chrono>
+
+namespace folly { namespace wangle {
+
+/* Codel algorithm implementation:
+ * http://en.wikipedia.org/wiki/CoDel
+ *
+ * Algorithm modified slightly: Instead of changing the interval time
+ * based on the average min delay, instead we use an alternate timeout
+ * for each task if the min delay during the interval period is too
+ * high.
+ *
+ * This was found to have better latency metrics than changing the
+ * window size, since we can communicate with the sender via thrift
+ * instead of only via the tcp window size congestion control, as in TCP.
+ */
+class Codel {
+
+ public:
+ Codel();
+
+ // Given a delay, returns wether the codel algorithm would
+ // reject a queued request with this delay.
+ //
+ // Internally, it also keeps track of the interval
+ bool overloaded(std::chrono::microseconds delay);
+
+ // Get the queue load, as seen by the codel algorithm
+ // Gives a rough guess at how bad the queue delay is.
+ //
+ // Return: 0 = no delay, 100 = At the queueing limit
+ int getLoad();
+
+ int getMinDelay();
+
+ private:
+ std::chrono::microseconds codelMinDelay_;
+ std::chrono::time_point<std::chrono::steady_clock> codelIntervalTime_;
+
+ // flag to make overloaded() thread-safe, since we only want
+ // to reset the delay once per time period
+ std::atomic<bool> codelResetDelay_;
+
+ bool overloaded_;
+};
+
+}} // Namespace
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <folly/wangle/futures/Future.h>
+
+namespace folly { namespace wangle {
+
+template <typename ExecutorImpl>
+class FutureExecutor : public ExecutorImpl {
+ public:
+ template <typename... Args>
+ explicit FutureExecutor(Args&&... args)
+ : ExecutorImpl(std::forward<Args>(args)...) {}
+
+ /*
+ * Given a function func that returns a Future<T>, adds that function to the
+ * contained Executor and returns a Future<T> which will be fulfilled with
+ * func's result once it has been executed.
+ *
+ * For example: auto f = futureExecutor.addFuture([](){
+ * return doAsyncWorkAndReturnAFuture();
+ * });
+ */
+ template <typename F>
+ typename std::enable_if<isFuture<typename std::result_of<F()>::type>::value,
+ typename std::result_of<F()>::type>::type
+ addFuture(F func) {
+ typedef typename std::result_of<F()>::type::value_type T;
+ Promise<T> promise;
+ auto future = promise.getFuture();
+ auto movePromise = folly::makeMoveWrapper(std::move(promise));
+ auto moveFunc = folly::makeMoveWrapper(std::move(func));
+ ExecutorImpl::add([movePromise, moveFunc] () mutable {
+ (*moveFunc)().then([movePromise] (Try<T>&& t) mutable {
+ movePromise->fulfilTry(std::move(t));
+ });
+ });
+ return future;
+ }
+
+ /*
+ * Similar to addFuture above, but takes a func that returns some non-Future
+ * type T.
+ *
+ * For example: auto f = futureExecutor.addFuture([]() {
+ * return 42;
+ * });
+ */
+ template <typename F>
+ typename std::enable_if<!isFuture<typename std::result_of<F()>::type>::value,
+ Future<typename std::result_of<F()>::type>>::type
+ addFuture(F func) {
+ typedef typename std::result_of<F()>::type T;
+ Promise<T> promise;
+ auto future = promise.getFuture();
+ auto movePromise = folly::makeMoveWrapper(std::move(promise));
+ auto moveFunc = folly::makeMoveWrapper(std::move(func));
+ ExecutorImpl::add([movePromise, moveFunc] () mutable {
+ movePromise->fulfil(std::move(*moveFunc));
+ });
+ return future;
+ }
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/experimental/Singleton.h>
+#include <folly/wangle/concurrent/IOExecutor.h>
+#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
+
+using namespace folly;
+using namespace folly::wangle;
+
+namespace {
+
+Singleton<IOThreadPoolExecutor> globalIOThreadPoolSingleton(
+ "GlobalIOThreadPool",
+ [](){
+ return new IOThreadPoolExecutor(
+ sysconf(_SC_NPROCESSORS_ONLN),
+ std::make_shared<NamedThreadFactory>("GlobalIOThreadPool"));
+ });
+
+}
+
+namespace folly { namespace wangle {
+
+IOExecutor* getIOExecutor() {
+ auto singleton = IOExecutor::getSingleton();
+ auto executor = singleton->load();
+ while (!executor) {
+ IOExecutor* nullIOExecutor = nullptr;
+ singleton->compare_exchange_strong(
+ nullIOExecutor,
+ Singleton<IOThreadPoolExecutor>::get("GlobalIOThreadPool"));
+ executor = singleton->load();
+ }
+ return executor;
+}
+
+void setIOExecutor(IOExecutor* executor) {
+ IOExecutor::getSingleton()->store(executor);
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace folly { namespace wangle {
+
+class IOExecutor;
+
+// Retrieve the global IOExecutor. If there is none, a default
+// IOThreadPoolExecutor will be constructed and returned.
+IOExecutor* getIOExecutor();
+
+// Set an IOExecutor to be the global IOExecutor which will be returned by
+// subsequent calls to getIOExecutor(). IOExecutors will uninstall themselves
+// as global when they are destructed.
+void setIOExecutor(IOExecutor* executor);
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/wangle/concurrent/IOExecutor.h>
+
+#include <folly/experimental/Singleton.h>
+#include <folly/wangle/concurrent/GlobalExecutor.h>
+
+using folly::Singleton;
+using folly::wangle::IOExecutor;
+
+namespace {
+
+Singleton<std::atomic<IOExecutor*>> globalIOExecutorSingleton(
+ "GlobalIOExecutor",
+ [](){
+ return new std::atomic<IOExecutor*>(nullptr);
+ });
+
+}
+
+namespace folly { namespace wangle {
+
+IOExecutor::~IOExecutor() {
+ auto thisCopy = this;
+ try {
+ getSingleton()->compare_exchange_strong(thisCopy, nullptr);
+ } catch (const std::runtime_error& e) {
+ // The global IOExecutor singleton was already destructed so doesn't need to
+ // be restored. Ignore.
+ }
+}
+
+std::atomic<IOExecutor*>* IOExecutor::getSingleton() {
+ return Singleton<std::atomic<IOExecutor*>>::get("GlobalIOExecutor");
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <folly/Executor.h>
+
+namespace folly {
+class EventBase;
+}
+
+namespace folly { namespace wangle {
+
+// An IOExecutor is an executor that operates on at least one EventBase. One of
+// these EventBases should be accessible via getEventBase(). The event base
+// returned by a call to getEventBase() is implementation dependent.
+//
+// Note that IOExecutors don't necessarily loop on the base themselves - for
+// instance, EventBase itself is an IOExecutor but doesn't drive itself.
+//
+// Implementations of IOExecutor are eligible to become the global IO executor,
+// returned on every call to getIOExecutor(), via setIOExecutor().
+// These functions are declared in GlobalExecutor.h
+//
+// If getIOExecutor is called and none has been set, a default global
+// IOThreadPoolExecutor will be created and returned.
+class IOExecutor : public virtual Executor {
+ public:
+ virtual ~IOExecutor();
+ virtual EventBase* getEventBase() = 0;
+
+ private:
+ static std::atomic<IOExecutor*>* getSingleton();
+ friend IOExecutor* getIOExecutor();
+ friend void setIOExecutor(IOExecutor* executor);
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
+
+#include <folly/MoveWrapper.h>
+#include <glog/logging.h>
+#include <folly/io/async/EventBaseManager.h>
+
+#include <folly/detail/MemoryIdler.h>
+
+namespace folly { namespace wangle {
+
+using folly::detail::MemoryIdler;
+
+/* Class that will free jemalloc caches and madvise the stack away
+ * if the event loop is unused for some period of time
+ */
+class MemoryIdlerTimeout
+ : public AsyncTimeout , public EventBase::LoopCallback {
+ public:
+ explicit MemoryIdlerTimeout(EventBase* b) : AsyncTimeout(b), base_(b) {}
+
+ virtual void timeoutExpired() noexcept {
+ idled = true;
+ }
+
+ virtual void runLoopCallback() noexcept {
+ if (idled) {
+ MemoryIdler::flushLocalMallocCaches();
+ MemoryIdler::unmapUnusedStack(MemoryIdler::kDefaultStackToRetain);
+
+ idled = false;
+ } else {
+ std::chrono::steady_clock::duration idleTimeout =
+ MemoryIdler::defaultIdleTimeout.load(
+ std::memory_order_acquire);
+
+ idleTimeout = MemoryIdler::getVariationTimeout(idleTimeout);
+
+ scheduleTimeout(std::chrono::duration_cast<std::chrono::milliseconds>(
+ idleTimeout).count());
+ }
+
+ // reschedule this callback for the next event loop.
+ base_->runBeforeLoop(this);
+ }
+ private:
+ EventBase* base_;
+ bool idled{false};
+} ;
+
+IOThreadPoolExecutor::IOThreadPoolExecutor(
+ size_t numThreads,
+ std::shared_ptr<ThreadFactory> threadFactory)
+ : ThreadPoolExecutor(numThreads, std::move(threadFactory)),
+ nextThread_(0) {
+ addThreads(numThreads);
+ CHECK(threadList_.get().size() == numThreads);
+}
+
+IOThreadPoolExecutor::~IOThreadPoolExecutor() {
+ stop();
+}
+
+void IOThreadPoolExecutor::add(Func func) {
+ add(std::move(func), std::chrono::milliseconds(0));
+}
+
+void IOThreadPoolExecutor::add(
+ Func func,
+ std::chrono::milliseconds expiration,
+ Func expireCallback) {
+ RWSpinLock::ReadHolder{&threadListLock_};
+ if (threadList_.get().empty()) {
+ throw std::runtime_error("No threads available");
+ }
+ auto ioThread = pickThread();
+
+ auto moveTask = folly::makeMoveWrapper(
+ Task(std::move(func), expiration, std::move(expireCallback)));
+ auto wrappedFunc = [ioThread, moveTask] () mutable {
+ runTask(ioThread, std::move(*moveTask));
+ ioThread->pendingTasks--;
+ };
+
+ ioThread->pendingTasks++;
+ if (!ioThread->eventBase->runInEventBaseThread(std::move(wrappedFunc))) {
+ ioThread->pendingTasks--;
+ throw std::runtime_error("Unable to run func in event base thread");
+ }
+}
+
+std::shared_ptr<IOThreadPoolExecutor::IOThread>
+IOThreadPoolExecutor::pickThread() {
+ if (*thisThread_) {
+ return *thisThread_;
+ }
+ auto thread = threadList_.get()[nextThread_++ % threadList_.get().size()];
+ return std::static_pointer_cast<IOThread>(thread);
+}
+
+EventBase* IOThreadPoolExecutor::getEventBase() {
+ return pickThread()->eventBase;
+}
+
+std::shared_ptr<ThreadPoolExecutor::Thread>
+IOThreadPoolExecutor::makeThread() {
+ return std::make_shared<IOThread>(this);
+}
+
+void IOThreadPoolExecutor::threadRun(ThreadPtr thread) {
+ const auto ioThread = std::static_pointer_cast<IOThread>(thread);
+ ioThread->eventBase =
+ folly::EventBaseManager::get()->getEventBase();
+ thisThread_.reset(new std::shared_ptr<IOThread>(ioThread));
+
+ auto idler = new MemoryIdlerTimeout(ioThread->eventBase);
+ ioThread->eventBase->runBeforeLoop(idler);
+
+ thread->startupBaton.post();
+ while (ioThread->shouldRun) {
+ ioThread->eventBase->loopForever();
+ }
+ if (isJoin_) {
+ while (ioThread->pendingTasks > 0) {
+ ioThread->eventBase->loopOnce();
+ }
+ }
+ stoppedThreads_.add(ioThread);
+}
+
+// threadListLock_ is writelocked
+void IOThreadPoolExecutor::stopThreads(size_t n) {
+ for (size_t i = 0; i < n; i++) {
+ const auto ioThread = std::static_pointer_cast<IOThread>(
+ threadList_.get()[i]);
+ ioThread->shouldRun = false;
+ ioThread->eventBase->terminateLoopSoon();
+ }
+}
+
+std::vector<EventBase*> IOThreadPoolExecutor::getEventBases() {
+ std::vector<EventBase*> bases;
+ RWSpinLock::ReadHolder{&threadListLock_};
+ for (const auto& thread : threadList_.get()) {
+ auto ioThread = std::static_pointer_cast<IOThread>(thread);
+ bases.push_back(ioThread->eventBase);
+ }
+ return bases;
+}
+
+// threadListLock_ is readlocked
+uint64_t IOThreadPoolExecutor::getPendingTaskCount() {
+ uint64_t count = 0;
+ for (const auto& thread : threadList_.get()) {
+ auto ioThread = std::static_pointer_cast<IOThread>(thread);
+ size_t pendingTasks = ioThread->pendingTasks;
+ if (pendingTasks > 0 && !ioThread->idle) {
+ pendingTasks--;
+ }
+ count += pendingTasks;
+ }
+ return count;
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/io/async/EventBase.h>
+#include <folly/wangle/concurrent/IOExecutor.h>
+#include <folly/wangle/concurrent/ThreadPoolExecutor.h>
+
+namespace folly { namespace wangle {
+
+// N.B. For this thread pool, stop() behaves like join() because outstanding
+// tasks belong to the event base and will be executed upon its destruction.
+class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor {
+ public:
+ explicit IOThreadPoolExecutor(
+ size_t numThreads,
+ std::shared_ptr<ThreadFactory> threadFactory =
+ std::make_shared<NamedThreadFactory>("IOThreadPool"));
+
+ ~IOThreadPoolExecutor();
+
+ void add(Func func) override;
+ void add(
+ Func func,
+ std::chrono::milliseconds expiration,
+ Func expireCallback = nullptr) override;
+
+ EventBase* getEventBase() override;
+
+ std::vector<EventBase*> getEventBases();
+
+ private:
+ struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread {
+ IOThread(IOThreadPoolExecutor* pool)
+ : Thread(pool),
+ shouldRun(true),
+ pendingTasks(0) {};
+ std::atomic<bool> shouldRun;
+ std::atomic<size_t> pendingTasks;
+ EventBase* eventBase;
+ };
+
+ ThreadPtr makeThread() override;
+ std::shared_ptr<IOThread> pickThread();
+ void threadRun(ThreadPtr thread) override;
+ void stopThreads(size_t n) override;
+ uint64_t getPendingTaskCount() override;
+
+ size_t nextThread_;
+ ThreadLocal<std::shared_ptr<IOThread>> thisThread_;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <folly/wangle/concurrent/BlockingQueue.h>
+#include <folly/LifoSem.h>
+#include <folly/MPMCQueue.h>
+
+namespace folly { namespace wangle {
+
+template <class T>
+class LifoSemMPMCQueue : public BlockingQueue<T> {
+ public:
+ explicit LifoSemMPMCQueue(size_t max_capacity) : queue_(max_capacity) {}
+
+ void add(T item) override {
+ if (!queue_.write(std::move(item))) {
+ throw std::runtime_error("LifoSemMPMCQueue full, can't add item");
+ }
+ sem_.post();
+ }
+
+ T take() override {
+ T item;
+ while (!queue_.read(item)) {
+ sem_.wait();
+ }
+ return item;
+ }
+
+ size_t capacity() {
+ return queue_.capacity();
+ }
+
+ size_t size() override {
+ return queue_.size();
+ }
+
+ private:
+ LifoSem sem_;
+ MPMCQueue<T> queue_;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <string>
+#include <thread>
+
+#include <folly/wangle/concurrent/ThreadFactory.h>
+#include <folly/Conv.h>
+#include <folly/Range.h>
+#include <folly/ThreadName.h>
+
+namespace folly { namespace wangle {
+
+class NamedThreadFactory : public ThreadFactory {
+ public:
+ explicit NamedThreadFactory(folly::StringPiece prefix)
+ : prefix_(prefix.str()), suffix_(0) {}
+
+ std::thread newThread(Func&& func) override {
+ auto thread = std::thread(std::move(func));
+ folly::setThreadName(
+ thread.native_handle(),
+ folly::to<std::string>(prefix_, suffix_++));
+ return thread;
+ }
+
+ void setNamePrefix(folly::StringPiece prefix) {
+ prefix_ = prefix.str();
+ }
+
+ std::string getNamePrefix() {
+ return prefix_;
+ }
+
+ private:
+ std::string prefix_;
+ std::atomic<uint64_t> suffix_;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <folly/wangle/concurrent/BlockingQueue.h>
+#include <folly/LifoSem.h>
+#include <folly/MPMCQueue.h>
+
+namespace folly { namespace wangle {
+
+template <class T>
+class PriorityLifoSemMPMCQueue : public BlockingQueue<T> {
+ public:
+ explicit PriorityLifoSemMPMCQueue(uint32_t numPriorities, size_t capacity) {
+ CHECK(numPriorities > 0);
+ queues_.reserve(numPriorities);
+ for (int i = 0; i < numPriorities; i++) {
+ queues_.push_back(MPMCQueue<T>(capacity));
+ }
+ }
+
+ uint32_t getNumPriorities() override {
+ return queues_.size();
+ }
+
+ // Add at lowest priority by default
+ void add(T item) override {
+ addWithPriority(std::move(item), 0);
+ }
+
+ void addWithPriority(T item, uint32_t priority) override {
+ CHECK(priority < queues_.size());
+ if (!queues_[priority].write(std::move(item))) {
+ throw std::runtime_error("LifoSemMPMCQueue full, can't add item");
+ }
+ sem_.post();
+ }
+
+ T take() override {
+ T item;
+ while (true) {
+ for (auto it = queues_.rbegin(); it != queues_.rend(); it++) {
+ if (it->read(item)) {
+ return item;
+ }
+ }
+ sem_.wait();
+ }
+ }
+
+ size_t size() override {
+ size_t size = 0;
+ for (auto& q : queues_) {
+ size += q.size();
+ }
+ return size;
+ }
+
+ private:
+ LifoSem sem_;
+ std::vector<MPMCQueue<T>> queues_;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <folly/Executor.h>
+
+#include <thread>
+
+namespace folly { namespace wangle {
+
+class ThreadFactory {
+ public:
+ virtual ~ThreadFactory() {}
+ virtual std::thread newThread(Func&& func) = 0;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/concurrent/ThreadPoolExecutor.h>
+
+namespace folly { namespace wangle {
+
+ThreadPoolExecutor::ThreadPoolExecutor(
+ size_t numThreads,
+ std::shared_ptr<ThreadFactory> threadFactory)
+ : threadFactory_(std::move(threadFactory)),
+ taskStatsSubject_(std::make_shared<Subject<TaskStats>>()) {}
+
+ThreadPoolExecutor::~ThreadPoolExecutor() {
+ CHECK(threadList_.get().size() == 0);
+}
+
+ThreadPoolExecutor::Task::Task(
+ Func&& func,
+ std::chrono::milliseconds expiration,
+ Func&& expireCallback)
+ : func_(std::move(func)),
+ expiration_(expiration),
+ expireCallback_(std::move(expireCallback)) {
+ // Assume that the task in enqueued on creation
+ enqueueTime_ = std::chrono::steady_clock::now();
+}
+
+void ThreadPoolExecutor::runTask(
+ const ThreadPtr& thread,
+ Task&& task) {
+ thread->idle = false;
+ auto startTime = std::chrono::steady_clock::now();
+ task.stats_.waitTime = startTime - task.enqueueTime_;
+ if (task.expiration_ > std::chrono::milliseconds(0) &&
+ task.stats_.waitTime >= task.expiration_) {
+ task.stats_.expired = true;
+ if (task.expireCallback_ != nullptr) {
+ task.expireCallback_();
+ }
+ } else {
+ try {
+ task.func_();
+ } catch (const std::exception& e) {
+ LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " <<
+ typeid(e).name() << " exception: " << e.what();
+ } catch (...) {
+ LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
+ "object";
+ }
+ task.stats_.runTime = std::chrono::steady_clock::now() - startTime;
+ }
+ thread->idle = true;
+ thread->taskStatsSubject->onNext(std::move(task.stats_));
+}
+
+size_t ThreadPoolExecutor::numThreads() {
+ RWSpinLock::ReadHolder{&threadListLock_};
+ return threadList_.get().size();
+}
+
+void ThreadPoolExecutor::setNumThreads(size_t n) {
+ RWSpinLock::WriteHolder{&threadListLock_};
+ const auto current = threadList_.get().size();
+ if (n > current ) {
+ addThreads(n - current);
+ } else if (n < current) {
+ removeThreads(current - n, true);
+ }
+ CHECK(threadList_.get().size() == n);
+}
+
+// threadListLock_ is writelocked
+void ThreadPoolExecutor::addThreads(size_t n) {
+ std::vector<ThreadPtr> newThreads;
+ for (size_t i = 0; i < n; i++) {
+ newThreads.push_back(makeThread());
+ }
+ for (auto& thread : newThreads) {
+ // TODO need a notion of failing to create the thread
+ // and then handling for that case
+ thread->handle = threadFactory_->newThread(
+ std::bind(&ThreadPoolExecutor::threadRun, this, thread));
+ threadList_.add(thread);
+ }
+ for (auto& thread : newThreads) {
+ thread->startupBaton.wait();
+ }
+}
+
+// threadListLock_ is writelocked
+void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) {
+ CHECK(n <= threadList_.get().size());
+ CHECK(stoppedThreads_.size() == 0);
+ isJoin_ = isJoin;
+ stopThreads(n);
+ for (size_t i = 0; i < n; i++) {
+ auto thread = stoppedThreads_.take();
+ thread->handle.join();
+ threadList_.remove(thread);
+ }
+ CHECK(stoppedThreads_.size() == 0);
+}
+
+void ThreadPoolExecutor::stop() {
+ RWSpinLock::WriteHolder{&threadListLock_};
+ removeThreads(threadList_.get().size(), false);
+ CHECK(threadList_.get().size() == 0);
+}
+
+void ThreadPoolExecutor::join() {
+ RWSpinLock::WriteHolder{&threadListLock_};
+ removeThreads(threadList_.get().size(), true);
+ CHECK(threadList_.get().size() == 0);
+}
+
+ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() {
+ RWSpinLock::ReadHolder{&threadListLock_};
+ ThreadPoolExecutor::PoolStats stats;
+ stats.threadCount = threadList_.get().size();
+ for (auto thread : threadList_.get()) {
+ if (thread->idle) {
+ stats.idleThreadCount++;
+ } else {
+ stats.activeThreadCount++;
+ }
+ }
+ stats.pendingTaskCount = getPendingTaskCount();
+ stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount;
+ return stats;
+}
+
+std::atomic<uint64_t> ThreadPoolExecutor::Thread::nextId(0);
+
+void ThreadPoolExecutor::StoppedThreadQueue::add(
+ ThreadPoolExecutor::ThreadPtr item) {
+ std::lock_guard<std::mutex> guard(mutex_);
+ queue_.push(std::move(item));
+ sem_.post();
+}
+
+ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() {
+ while(1) {
+ {
+ std::lock_guard<std::mutex> guard(mutex_);
+ if (queue_.size() > 0) {
+ auto item = std::move(queue_.front());
+ queue_.pop();
+ return item;
+ }
+ }
+ sem_.wait();
+ }
+}
+
+size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
+ std::lock_guard<std::mutex> guard(mutex_);
+ return queue_.size();
+}
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <folly/Executor.h>
+#include <folly/wangle/concurrent/LifoSemMPMCQueue.h>
+#include <folly/wangle/concurrent/NamedThreadFactory.h>
+#include <folly/wangle/rx/Observable.h>
+#include <folly/Baton.h>
+#include <folly/Memory.h>
+#include <folly/RWSpinLock.h>
+
+#include <algorithm>
+#include <mutex>
+#include <queue>
+
+#include <glog/logging.h>
+
+namespace folly { namespace wangle {
+
+class ThreadPoolExecutor : public virtual Executor {
+ public:
+ explicit ThreadPoolExecutor(
+ size_t numThreads,
+ std::shared_ptr<ThreadFactory> threadFactory);
+
+ ~ThreadPoolExecutor();
+
+ virtual void add(Func func) override = 0;
+ virtual void add(
+ Func func,
+ std::chrono::milliseconds expiration,
+ Func expireCallback) = 0;
+
+ void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
+ CHECK(numThreads() == 0);
+ threadFactory_ = std::move(threadFactory);
+ }
+
+ std::shared_ptr<ThreadFactory> getThreadFactory(void) {
+ return threadFactory_;
+ }
+
+ size_t numThreads();
+ void setNumThreads(size_t numThreads);
+ /*
+ * stop() is best effort - there is no guarantee that unexecuted tasks won't
+ * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
+ * behaves like join().
+ */
+ void stop();
+ void join();
+
+ struct PoolStats {
+ PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
+ pendingTaskCount(0), totalTaskCount(0) {}
+ size_t threadCount, idleThreadCount, activeThreadCount;
+ uint64_t pendingTaskCount, totalTaskCount;
+ };
+
+ PoolStats getPoolStats();
+
+ struct TaskStats {
+ TaskStats() : expired(false), waitTime(0), runTime(0) {}
+ bool expired;
+ std::chrono::nanoseconds waitTime;
+ std::chrono::nanoseconds runTime;
+ };
+
+ Subscription<TaskStats> subscribeToTaskStats(
+ const ObserverPtr<TaskStats>& observer) {
+ return taskStatsSubject_->subscribe(observer);
+ }
+
+ protected:
+ // Prerequisite: threadListLock_ writelocked
+ void addThreads(size_t n);
+ // Prerequisite: threadListLock_ writelocked
+ void removeThreads(size_t n, bool isJoin);
+
+ struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread {
+ explicit Thread(ThreadPoolExecutor* pool)
+ : id(nextId++),
+ handle(),
+ idle(true),
+ taskStatsSubject(pool->taskStatsSubject_) {}
+
+ virtual ~Thread() {}
+
+ static std::atomic<uint64_t> nextId;
+ uint64_t id;
+ std::thread handle;
+ bool idle;
+ Baton<> startupBaton;
+ std::shared_ptr<Subject<TaskStats>> taskStatsSubject;
+ };
+
+ typedef std::shared_ptr<Thread> ThreadPtr;
+
+ struct Task {
+ explicit Task(
+ Func&& func,
+ std::chrono::milliseconds expiration,
+ Func&& expireCallback);
+ Func func_;
+ TaskStats stats_;
+ std::chrono::steady_clock::time_point enqueueTime_;
+ std::chrono::milliseconds expiration_;
+ Func expireCallback_;
+ };
+
+ static void runTask(const ThreadPtr& thread, Task&& task);
+
+ // The function that will be bound to pool threads. It must call
+ // thread->startupBaton.post() when it's ready to consume work.
+ virtual void threadRun(ThreadPtr thread) = 0;
+
+ // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
+ // Prerequisite: threadListLock_ writelocked
+ virtual void stopThreads(size_t n) = 0;
+
+ // Create a suitable Thread struct
+ virtual ThreadPtr makeThread() {
+ return std::make_shared<Thread>(this);
+ }
+
+ // Prerequisite: threadListLock_ readlocked
+ virtual uint64_t getPendingTaskCount() = 0;
+
+ class ThreadList {
+ public:
+ void add(const ThreadPtr& state) {
+ auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare);
+ vec_.insert(it, state);
+ }
+
+ void remove(const ThreadPtr& state) {
+ auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare);
+ CHECK(itPair.first != vec_.end());
+ CHECK(std::next(itPair.first) == itPair.second);
+ vec_.erase(itPair.first);
+ }
+
+ const std::vector<ThreadPtr>& get() const {
+ return vec_;
+ }
+
+ private:
+ static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
+ return ts1->id < ts2->id;
+ }
+
+ std::vector<ThreadPtr> vec_;
+ };
+
+ class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
+ public:
+ void add(ThreadPtr item) override;
+ ThreadPtr take() override;
+ size_t size() override;
+
+ private:
+ LifoSem sem_;
+ std::mutex mutex_;
+ std::queue<ThreadPtr> queue_;
+ };
+
+ std::shared_ptr<ThreadFactory> threadFactory_;
+ ThreadList threadList_;
+ RWSpinLock threadListLock_;
+ StoppedThreadQueue stoppedThreads_;
+ std::atomic<bool> isJoin_; // whether the current downsizing is a join
+
+ std::shared_ptr<Subject<TaskStats>> taskStatsSubject_;
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <chrono>
+#include <folly/wangle/concurrent/Codel.h>
+#include <gtest/gtest.h>
+#include <thread>
+
+TEST(CodelTest, Basic) {
+ using std::chrono::milliseconds;
+ folly::wangle::Codel c;
+ std::this_thread::sleep_for(milliseconds(110));
+ // This interval is overloaded
+ EXPECT_FALSE(c.overloaded(milliseconds(100)));
+ std::this_thread::sleep_for(milliseconds(90));
+ // At least two requests must happen in an interval before they will fail
+ EXPECT_FALSE(c.overloaded(milliseconds(50)));
+ EXPECT_TRUE(c.overloaded(milliseconds(50)));
+ std::this_thread::sleep_for(milliseconds(110));
+ // Previous interval is overloaded, but 2ms isn't enough to fail
+ EXPECT_FALSE(c.overloaded(milliseconds(2)));
+ std::this_thread::sleep_for(milliseconds(90));
+ // 20 ms > target interval * 2
+ EXPECT_TRUE(c.overloaded(milliseconds(20)));
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <folly/wangle/concurrent/GlobalExecutor.h>
+#include <folly/wangle/concurrent/IOExecutor.h>
+
+using namespace folly::wangle;
+
+TEST(GlobalExecutorTest, GlobalIOExecutor) {
+ class DummyExecutor : public IOExecutor {
+ public:
+ void add(folly::Func f) override {
+ count++;
+ }
+ folly::EventBase* getEventBase() override {
+ return nullptr;
+ }
+ int count{0};
+ };
+
+ auto f = [](){};
+
+ // Don't explode, we should create the default global IOExecutor lazily here.
+ getIOExecutor()->add(f);
+
+ {
+ DummyExecutor dummy;
+ setIOExecutor(&dummy);
+ getIOExecutor()->add(f);
+ // Make sure we were properly installed.
+ EXPECT_EQ(1, dummy.count);
+ }
+
+ // Don't explode, we should restore the default global IOExecutor when dummy
+ // is destructed.
+ getIOExecutor()->add(f);
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/concurrent/FutureExecutor.h>
+#include <folly/wangle/concurrent/ThreadPoolExecutor.h>
+#include <folly/wangle/concurrent/CPUThreadPoolExecutor.h>
+#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+using namespace folly::wangle;
+using namespace std::chrono;
+
+static folly::Func burnMs(uint64_t ms) {
+ return [ms]() { std::this_thread::sleep_for(milliseconds(ms)); };
+}
+
+template <class TPE>
+static void basic() {
+ // Create and destroy
+ TPE tpe(10);
+}
+
+TEST(ThreadPoolExecutorTest, CPUBasic) {
+ basic<CPUThreadPoolExecutor>();
+}
+
+TEST(IOThreadPoolExecutorTest, IOBasic) {
+ basic<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void resize() {
+ TPE tpe(100);
+ EXPECT_EQ(100, tpe.numThreads());
+ tpe.setNumThreads(50);
+ EXPECT_EQ(50, tpe.numThreads());
+ tpe.setNumThreads(150);
+ EXPECT_EQ(150, tpe.numThreads());
+}
+
+TEST(ThreadPoolExecutorTest, CPUResize) {
+ resize<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOResize) {
+ resize<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void stop() {
+ TPE tpe(1);
+ std::atomic<int> completed(0);
+ auto f = [&](){
+ burnMs(10)();
+ completed++;
+ };
+ for (int i = 0; i < 1000; i++) {
+ tpe.add(f);
+ }
+ tpe.stop();
+ EXPECT_GT(1000, completed);
+}
+
+// IOThreadPoolExecutor's stop() behaves like join(). Outstanding tasks belong
+// to the event base, will be executed upon its destruction, and cannot be
+// taken back.
+template <>
+void stop<IOThreadPoolExecutor>() {
+ IOThreadPoolExecutor tpe(1);
+ std::atomic<int> completed(0);
+ auto f = [&](){
+ burnMs(10)();
+ completed++;
+ };
+ for (int i = 0; i < 10; i++) {
+ tpe.add(f);
+ }
+ tpe.stop();
+ EXPECT_EQ(10, completed);
+}
+
+TEST(ThreadPoolExecutorTest, CPUStop) {
+ stop<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOStop) {
+ stop<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void join() {
+ TPE tpe(10);
+ std::atomic<int> completed(0);
+ auto f = [&](){
+ burnMs(1)();
+ completed++;
+ };
+ for (int i = 0; i < 1000; i++) {
+ tpe.add(f);
+ }
+ tpe.join();
+ EXPECT_EQ(1000, completed);
+}
+
+TEST(ThreadPoolExecutorTest, CPUJoin) {
+ join<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOJoin) {
+ join<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void resizeUnderLoad() {
+ TPE tpe(10);
+ std::atomic<int> completed(0);
+ auto f = [&](){
+ burnMs(1)();
+ completed++;
+ };
+ for (int i = 0; i < 1000; i++) {
+ tpe.add(f);
+ }
+ tpe.setNumThreads(5);
+ tpe.setNumThreads(15);
+ tpe.join();
+ EXPECT_EQ(1000, completed);
+}
+
+TEST(ThreadPoolExecutorTest, CPUResizeUnderLoad) {
+ resizeUnderLoad<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOResizeUnderLoad) {
+ resizeUnderLoad<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void poolStats() {
+ folly::Baton<> startBaton, endBaton;
+ TPE tpe(1);
+ auto stats = tpe.getPoolStats();
+ EXPECT_EQ(1, stats.threadCount);
+ EXPECT_EQ(1, stats.idleThreadCount);
+ EXPECT_EQ(0, stats.activeThreadCount);
+ EXPECT_EQ(0, stats.pendingTaskCount);
+ EXPECT_EQ(0, stats.totalTaskCount);
+ tpe.add([&](){ startBaton.post(); endBaton.wait(); });
+ tpe.add([&](){});
+ startBaton.wait();
+ stats = tpe.getPoolStats();
+ EXPECT_EQ(1, stats.threadCount);
+ EXPECT_EQ(0, stats.idleThreadCount);
+ EXPECT_EQ(1, stats.activeThreadCount);
+ EXPECT_EQ(1, stats.pendingTaskCount);
+ EXPECT_EQ(2, stats.totalTaskCount);
+ endBaton.post();
+}
+
+TEST(ThreadPoolExecutorTest, CPUPoolStats) {
+ poolStats<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOPoolStats) {
+ poolStats<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void taskStats() {
+ TPE tpe(1);
+ std::atomic<int> c(0);
+ auto s = tpe.subscribeToTaskStats(
+ Observer<ThreadPoolExecutor::TaskStats>::create(
+ [&](ThreadPoolExecutor::TaskStats stats) {
+ int i = c++;
+ EXPECT_LT(milliseconds(0), stats.runTime);
+ if (i == 1) {
+ EXPECT_LT(milliseconds(0), stats.waitTime);
+ }
+ }));
+ tpe.add(burnMs(10));
+ tpe.add(burnMs(10));
+ tpe.join();
+ EXPECT_EQ(2, c);
+}
+
+TEST(ThreadPoolExecutorTest, CPUTaskStats) {
+ taskStats<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOTaskStats) {
+ taskStats<IOThreadPoolExecutor>();
+}
+
+template <class TPE>
+static void expiration() {
+ TPE tpe(1);
+ std::atomic<int> statCbCount(0);
+ auto s = tpe.subscribeToTaskStats(
+ Observer<ThreadPoolExecutor::TaskStats>::create(
+ [&](ThreadPoolExecutor::TaskStats stats) {
+ int i = statCbCount++;
+ if (i == 0) {
+ EXPECT_FALSE(stats.expired);
+ } else if (i == 1) {
+ EXPECT_TRUE(stats.expired);
+ } else {
+ FAIL();
+ }
+ }));
+ std::atomic<int> expireCbCount(0);
+ auto expireCb = [&] () { expireCbCount++; };
+ tpe.add(burnMs(10), seconds(60), expireCb);
+ tpe.add(burnMs(10), milliseconds(10), expireCb);
+ tpe.join();
+ EXPECT_EQ(2, statCbCount);
+ EXPECT_EQ(1, expireCbCount);
+}
+
+TEST(ThreadPoolExecutorTest, CPUExpiration) {
+ expiration<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOExpiration) {
+ expiration<IOThreadPoolExecutor>();
+}
+
+template <typename TPE>
+static void futureExecutor() {
+ FutureExecutor<TPE> fe(2);
+ std::atomic<int> c{0};
+ fe.addFuture([] () { return makeFuture<int>(42); }).then(
+ [&] (Try<int>&& t) {
+ c++;
+ EXPECT_EQ(42, t.value());
+ });
+ fe.addFuture([] () { return 100; }).then(
+ [&] (Try<int>&& t) {
+ c++;
+ EXPECT_EQ(100, t.value());
+ });
+ fe.addFuture([] () { return makeFuture(); }).then(
+ [&] (Try<void>&& t) {
+ c++;
+ EXPECT_NO_THROW(t.value());
+ });
+ fe.addFuture([] () { return; }).then(
+ [&] (Try<void>&& t) {
+ c++;
+ EXPECT_NO_THROW(t.value());
+ });
+ fe.addFuture([] () { throw std::runtime_error("oops"); }).then(
+ [&] (Try<void>&& t) {
+ c++;
+ EXPECT_THROW(t.value(), std::runtime_error);
+ });
+ // Test doing actual async work
+ folly::Baton<> baton;
+ fe.addFuture([&] () {
+ auto p = std::make_shared<Promise<int>>();
+ std::thread t([p](){
+ burnMs(10)();
+ p->setValue(42);
+ });
+ t.detach();
+ return p->getFuture();
+ }).then([&] (Try<int>&& t) {
+ EXPECT_EQ(42, t.value());
+ c++;
+ baton.post();
+ });
+ baton.wait();
+ fe.join();
+ EXPECT_EQ(6, c);
+}
+
+TEST(ThreadPoolExecutorTest, CPUFuturePool) {
+ futureExecutor<CPUThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, IOFuturePool) {
+ futureExecutor<IOThreadPoolExecutor>();
+}
+
+TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) {
+ bool tookLopri = false;
+ auto completed = 0;
+ auto hipri = [&] {
+ EXPECT_FALSE(tookLopri);
+ completed++;
+ };
+ auto lopri = [&] {
+ tookLopri = true;
+ completed++;
+ };
+ CPUThreadPoolExecutor pool(0, 2);
+ for (int i = 0; i < 50; i++) {
+ pool.add(lopri, 0);
+ }
+ for (int i = 0; i < 50; i++) {
+ pool.add(hipri, 1);
+ }
+ pool.setNumThreads(1);
+ pool.join();
+ EXPECT_EQ(100, completed);
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// fbbuild is too dumb to know that .h files in the directory affect
+// our project, unless we have a .cpp file in the target, in the same
+// directory.
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/rx/Subject.h>
+#include <folly/wangle/rx/Subscription.h>
+#include <folly/wangle/rx/types.h>
+
+#include <folly/RWSpinLock.h>
+#include <folly/SmallLocks.h>
+#include <folly/ThreadLocal.h>
+#include <folly/small_vector.h>
+#include <folly/Executor.h>
+#include <map>
+#include <memory>
+
+namespace folly { namespace wangle {
+
+template <class T, size_t InlineObservers>
+class Observable {
+ public:
+ Observable() : nextSubscriptionId_{1} {}
+
+ // TODO perhaps we want to provide this #5283229
+ Observable(Observable&& other) = delete;
+
+ virtual ~Observable() {
+ if (unsubscriber_) {
+ unsubscriber_->disable();
+ }
+ }
+
+ // The next three methods subscribe the given Observer to this Observable.
+ //
+ // If these are called within an Observer callback, the new observer will not
+ // get the current update but will get subsequent updates.
+ //
+ // subscribe() returns a Subscription object. The observer will continue to
+ // get updates until the Subscription is destroyed.
+ //
+ // observe(ObserverPtr<T>) creates an indefinite subscription
+ //
+ // observe(Observer<T>*) also creates an indefinite subscription, but the
+ // caller is responsible for ensuring that the given Observer outlives this
+ // Observable. This might be useful in high performance environments where
+ // allocations must be kept to a minimum. Template parameter InlineObservers
+ // specifies how many observers can been subscribed inline without any
+ // allocations (it's just the size of a folly::small_vector).
+ virtual Subscription<T> subscribe(ObserverPtr<T> observer) {
+ return subscribeImpl(observer, false);
+ }
+
+ virtual void observe(ObserverPtr<T> observer) {
+ subscribeImpl(observer, true);
+ }
+
+ virtual void observe(Observer<T>* observer) {
+ if (inCallback_ && *inCallback_) {
+ if (!newObservers_) {
+ newObservers_.reset(new ObserverList());
+ }
+ newObservers_->push_back(observer);
+ } else {
+ RWSpinLock::WriteHolder{&observersLock_};
+ observers_.push_back(observer);
+ }
+ }
+
+ // TODO unobserve(ObserverPtr<T>), unobserve(Observer<T>*)
+
+ /// Returns a new Observable that will call back on the given Scheduler.
+ /// The returned Observable must outlive the parent Observable.
+
+ // This and subscribeOn should maybe just be a first-class feature of an
+ // Observable, rather than making new ones whose lifetimes are tied to their
+ // parents. In that case it'd return a reference to this object for
+ // chaining.
+ ObservablePtr<T> observeOn(SchedulerPtr scheduler) {
+ // you're right Hannes, if we have Observable::create we don't need this
+ // helper class.
+ struct ViaSubject : public Observable<T>
+ {
+ ViaSubject(SchedulerPtr sched,
+ Observable* obs)
+ : scheduler_(sched), observable_(obs)
+ {}
+
+ Subscription<T> subscribe(ObserverPtr<T> o) override {
+ return observable_->subscribe(
+ Observer<T>::create(
+ [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); },
+ [=](Error e) { scheduler_->add([o, e] { o->onError(e); }); },
+ [=]() { scheduler_->add([o] { o->onCompleted(); }); }));
+ }
+
+ protected:
+ SchedulerPtr scheduler_;
+ Observable* observable_;
+ };
+
+ return std::make_shared<ViaSubject>(scheduler, this);
+ }
+
+ /// Returns a new Observable that will subscribe to this parent Observable
+ /// via the given Scheduler. This can be subtle and confusing at first, see
+ /// http://www.introtorx.com/Content/v1.0.10621.0/15_SchedulingAndThreading.html#SubscribeOnObserveOn
+ std::unique_ptr<Observable> subscribeOn(SchedulerPtr scheduler) {
+ struct Subject_ : public Subject<T> {
+ public:
+ Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) {
+ }
+
+ Subscription<T> subscribe(ObserverPtr<T> o) {
+ scheduler_->add([=] {
+ observable_->subscribe(o);
+ });
+ return Subscription<T>(nullptr, 0); // TODO
+ }
+
+ protected:
+ SchedulerPtr scheduler_;
+ Observable* observable_;
+ };
+
+ return folly::make_unique<Subject_>(scheduler, this);
+ }
+
+ protected:
+ // Safely execute an operation on each observer. F must take a single
+ // Observer<T>* as its argument.
+ template <class F>
+ void forEachObserver(F f) {
+ if (UNLIKELY(!inCallback_)) {
+ inCallback_.reset(new bool{false});
+ }
+ CHECK(!(*inCallback_));
+ *inCallback_ = true;
+
+ {
+ RWSpinLock::ReadHolder rh(observersLock_);
+ for (auto o : observers_) {
+ f(o);
+ }
+
+ for (auto& kv : subscribers_) {
+ f(kv.second.get());
+ }
+ }
+
+ if (UNLIKELY((newObservers_ && !newObservers_->empty()) ||
+ (newSubscribers_ && !newSubscribers_->empty()) ||
+ (oldSubscribers_ && !oldSubscribers_->empty()))) {
+ {
+ RWSpinLock::WriteHolder wh(observersLock_);
+ if (newObservers_) {
+ for (auto observer : *(newObservers_)) {
+ observers_.push_back(observer);
+ }
+ newObservers_->clear();
+ }
+ if (newSubscribers_) {
+ for (auto& kv : *(newSubscribers_)) {
+ subscribers_.insert(std::move(kv));
+ }
+ newSubscribers_->clear();
+ }
+ if (oldSubscribers_) {
+ for (auto id : *(oldSubscribers_)) {
+ subscribers_.erase(id);
+ }
+ oldSubscribers_->clear();
+ }
+ }
+ }
+ *inCallback_ = false;
+ }
+
+ private:
+ Subscription<T> subscribeImpl(ObserverPtr<T> observer, bool indefinite) {
+ auto subscription = makeSubscription(indefinite);
+ typename SubscriberMap::value_type kv{subscription.id_, std::move(observer)};
+ if (inCallback_ && *inCallback_) {
+ if (!newSubscribers_) {
+ newSubscribers_.reset(new SubscriberMap());
+ }
+ newSubscribers_->insert(std::move(kv));
+ } else {
+ RWSpinLock::WriteHolder{&observersLock_};
+ subscribers_.insert(std::move(kv));
+ }
+ return subscription;
+ }
+
+ class Unsubscriber {
+ public:
+ explicit Unsubscriber(Observable* observable) : observable_(observable) {
+ CHECK(observable_);
+ }
+
+ void unsubscribe(uint64_t id) {
+ CHECK(id > 0);
+ RWSpinLock::ReadHolder guard(lock_);
+ if (observable_) {
+ observable_->unsubscribe(id);
+ }
+ }
+
+ void disable() {
+ RWSpinLock::WriteHolder guard(lock_);
+ observable_ = nullptr;
+ }
+
+ private:
+ RWSpinLock lock_;
+ Observable* observable_;
+ };
+
+ std::shared_ptr<Unsubscriber> unsubscriber_{nullptr};
+ MicroSpinLock unsubscriberLock_{0};
+
+ friend class Subscription<T>;
+
+ void unsubscribe(uint64_t id) {
+ if (inCallback_ && *inCallback_) {
+ if (!oldSubscribers_) {
+ oldSubscribers_.reset(new std::vector<uint64_t>());
+ }
+ if (newSubscribers_) {
+ auto it = newSubscribers_->find(id);
+ if (it != newSubscribers_->end()) {
+ newSubscribers_->erase(it);
+ return;
+ }
+ }
+ oldSubscribers_->push_back(id);
+ } else {
+ RWSpinLock::WriteHolder{&observersLock_};
+ subscribers_.erase(id);
+ }
+ }
+
+ Subscription<T> makeSubscription(bool indefinite) {
+ if (indefinite) {
+ return Subscription<T>(nullptr, nextSubscriptionId_++);
+ } else {
+ if (!unsubscriber_) {
+ std::lock_guard<MicroSpinLock> guard(unsubscriberLock_);
+ if (!unsubscriber_) {
+ unsubscriber_ = std::make_shared<Unsubscriber>(this);
+ }
+ }
+ return Subscription<T>(unsubscriber_, nextSubscriptionId_++);
+ }
+ }
+
+ std::atomic<uint64_t> nextSubscriptionId_;
+ RWSpinLock observersLock_;
+ folly::ThreadLocalPtr<bool> inCallback_;
+
+ typedef folly::small_vector<Observer<T>*, InlineObservers> ObserverList;
+ ObserverList observers_;
+ folly::ThreadLocalPtr<ObserverList> newObservers_;
+
+ typedef std::map<uint64_t, ObserverPtr<T>> SubscriberMap;
+ SubscriberMap subscribers_;
+ folly::ThreadLocalPtr<SubscriberMap> newSubscribers_;
+ folly::ThreadLocalPtr<std::vector<uint64_t>> oldSubscribers_;
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/rx/types.h>
+#include <functional>
+#include <memory>
+#include <stdexcept>
+#include <folly/Memory.h>
+
+namespace folly { namespace wangle {
+
+template <class T> class FunctionObserver;
+
+/// Observer interface. You can subclass it, or you can just use create()
+/// to use std::functions.
+template <class T>
+struct Observer {
+ // These are what it means to be an Observer.
+ virtual void onNext(const T&) = 0;
+ virtual void onError(Error) = 0;
+ virtual void onCompleted() = 0;
+
+ virtual ~Observer() = default;
+
+ /// Create an Observer with std::function callbacks. Handy to make ad-hoc
+ /// Observers with lambdas.
+ ///
+ /// Templated for maximum perfect forwarding flexibility, but ultimately
+ /// whatever you pass in has to implicitly become a std::function for the
+ /// same signature as onNext(), onError(), and onCompleted() respectively.
+ /// (see the FunctionObserver typedefs)
+ template <class N, class E, class C>
+ static std::unique_ptr<Observer> create(
+ N&& onNextFn, E&& onErrorFn, C&& onCompletedFn)
+ {
+ return folly::make_unique<FunctionObserver<T>>(
+ std::forward<N>(onNextFn),
+ std::forward<E>(onErrorFn),
+ std::forward<C>(onCompletedFn));
+ }
+
+ /// Create an Observer with only onNext and onError callbacks.
+ /// onCompleted will just be a no-op.
+ template <class N, class E>
+ static std::unique_ptr<Observer> create(N&& onNextFn, E&& onErrorFn) {
+ return folly::make_unique<FunctionObserver<T>>(
+ std::forward<N>(onNextFn),
+ std::forward<E>(onErrorFn),
+ nullptr);
+ }
+
+ /// Create an Observer with only an onNext callback.
+ /// onError and onCompleted will just be no-ops.
+ template <class N>
+ static std::unique_ptr<Observer> create(N&& onNextFn) {
+ return folly::make_unique<FunctionObserver<T>>(
+ std::forward<N>(onNextFn),
+ nullptr,
+ nullptr);
+ }
+};
+
+/// An observer that uses std::function callbacks. You don't really want to
+/// make one of these directly - instead use the Observer::create() methods.
+template <class T>
+struct FunctionObserver : public Observer<T> {
+ typedef std::function<void(const T&)> OnNext;
+ typedef std::function<void(Error)> OnError;
+ typedef std::function<void()> OnCompleted;
+
+ /// We don't need any fancy overloads of this constructor because that's
+ /// what Observer::create() is for.
+ template <class N = OnNext, class E = OnError, class C = OnCompleted>
+ FunctionObserver(N&& n, E&& e, C&& c)
+ : onNext_(std::forward<N>(n)),
+ onError_(std::forward<E>(e)),
+ onCompleted_(std::forward<C>(c))
+ {}
+
+ void onNext(const T& val) override {
+ if (onNext_) onNext_(val);
+ }
+
+ void onError(Error e) override {
+ if (onError_) onError_(e);
+ }
+
+ void onCompleted() override {
+ if (onCompleted_) onCompleted_();
+ }
+
+ protected:
+ OnNext onNext_;
+ OnError onError_;
+ OnCompleted onCompleted_;
+};
+
+}}
--- /dev/null
+Rx is a pattern for "functional reactive programming" that started at
+Microsoft in C#, and has been reimplemented in various languages, notably
+RxJava for JVM languages.
+
+It is basically the plural of Futures (a la Wangle).
+
+
+ singular | plural
+ +---------------------------------+-----------------------------------
+ sync | Foo getData() | std::vector<Foo> getData()
+ async | wangle::Future<Foo> getData() | wangle::Observable<Foo> getData()
+
+
+For more on Rx, I recommend these resources:
+
+Netflix blog post (RxJava): http://techblog.netflix.com/2013/02/rxjava-netflix-api.html
+Introduction to Rx eBook (C#): http://www.introtorx.com/content/v1.0.10621.0/01_WhyRx.html
+The RxJava wiki: https://github.com/Netflix/RxJava/wiki
+Netflix QCon presentation: http://www.infoq.com/presentations/netflix-functional-rx
+https://rx.codeplex.com/
+
+There are open source C++ implementations, I haven't looked at them. They
+might be the best way to go rather than writing it NIH-style. I mostly did it
+as an exercise, to think through how closely we might want to integrate
+something like this with Wangle, and to get a feel for how it works in C++.
+
+I haven't even tried to support move-only data in this version. I'm on the
+fence about the usage of shared_ptr. Subject is underdeveloped. A whole rich
+set of operations is obviously missing. I haven't decided how to handle
+subscriptions (and therefore cancellation), but I'm pretty sure C#'s
+"Disposable" is thoroughly un-C++ (opposite of RAII). So for now subscribe
+returns nothing at all and you can't cancel anything ever. The whole thing is
+probably riddled with lifetime corner case bugs that will come out like a
+swarm of angry bees as soon as someone tries an infinite sequence, or tries to
+partially observe a long sequence. I'm pretty sure subscribeOn has a bug that
+I haven't tracked down yet.
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/rx/Observable.h>
+#include <folly/wangle/rx/Observer.h>
+
+namespace folly { namespace wangle {
+
+/// Subject interface. A Subject is both an Observable and an Observer. There
+/// is a default implementation of the Observer methods that just forwards the
+/// observed events to the Subject's observers.
+template <class T>
+struct Subject : public Observable<T>, public Observer<T> {
+ void onNext(const T& val) override {
+ this->forEachObserver([&](Observer<T>* o){
+ o->onNext(val);
+ });
+ }
+ void onError(Error e) override {
+ this->forEachObserver([&](Observer<T>* o){
+ o->onError(e);
+ });
+ }
+ void onCompleted() override {
+ this->forEachObserver([](Observer<T>* o){
+ o->onCompleted();
+ });
+ }
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/rx/Observable.h>
+
+namespace folly { namespace wangle {
+
+template <class T>
+class Subscription {
+ public:
+ Subscription() {}
+
+ Subscription(const Subscription&) = delete;
+
+ Subscription(Subscription&& other) noexcept {
+ *this = std::move(other);
+ }
+
+ Subscription& operator=(Subscription&& other) noexcept {
+ unsubscribe();
+ unsubscriber_ = std::move(other.unsubscriber_);
+ id_ = other.id_;
+ other.unsubscriber_ = nullptr;
+ other.id_ = 0;
+ return *this;
+ }
+
+ ~Subscription() {
+ unsubscribe();
+ }
+
+ private:
+ typedef typename Observable<T>::Unsubscriber Unsubscriber;
+
+ Subscription(std::shared_ptr<Unsubscriber> unsubscriber, uint64_t id)
+ : unsubscriber_(std::move(unsubscriber)), id_(id) {
+ CHECK(id_ > 0);
+ }
+
+ void unsubscribe() {
+ if (unsubscriber_ && id_ > 0) {
+ unsubscriber_->unsubscribe(id_);
+ id_ = 0;
+ unsubscriber_ = nullptr;
+ }
+ }
+
+ std::shared_ptr<Unsubscriber> unsubscriber_;
+ uint64_t id_{0};
+
+ friend class Observable<T>;
+};
+
+}}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/Benchmark.h>
+#include <folly/wangle/rx/Observer.h>
+#include <folly/wangle/rx/Subject.h>
+#include <gflags/gflags.h>
+
+using namespace folly::wangle;
+using folly::BenchmarkSuspender;
+
+static std::unique_ptr<Observer<int>> makeObserver() {
+ return Observer<int>::create([&] (int x) {});
+}
+
+void subscribeImpl(uint iters, int N, bool countUnsubscribe) {
+ for (uint iter = 0; iter < iters; iter++) {
+ BenchmarkSuspender bs;
+ Subject<int> subject;
+ std::vector<std::unique_ptr<Observer<int>>> observers;
+ std::vector<Subscription<int>> subscriptions;
+ subscriptions.reserve(N);
+ for (int i = 0; i < N; i++) {
+ observers.push_back(makeObserver());
+ }
+ bs.dismiss();
+ for (int i = 0; i < N; i++) {
+ subscriptions.push_back(subject.subscribe(std::move(observers[i])));
+ }
+ if (countUnsubscribe) {
+ subscriptions.clear();
+ }
+ bs.rehire();
+ }
+}
+
+void subscribeAndUnsubscribe(uint iters, int N) {
+ subscribeImpl(iters, N, true);
+}
+
+void subscribe(uint iters, int N) {
+ subscribeImpl(iters, N, false);
+}
+
+void observe(uint iters, int N) {
+ for (uint iter = 0; iter < iters; iter++) {
+ BenchmarkSuspender bs;
+ Subject<int> subject;
+ std::vector<std::unique_ptr<Observer<int>>> observers;
+ for (int i = 0; i < N; i++) {
+ observers.push_back(makeObserver());
+ }
+ bs.dismiss();
+ for (int i = 0; i < N; i++) {
+ subject.observe(std::move(observers[i]));
+ }
+ bs.rehire();
+ }
+}
+
+void inlineObserve(uint iters, int N) {
+ for (uint iter = 0; iter < iters; iter++) {
+ BenchmarkSuspender bs;
+ Subject<int> subject;
+ std::vector<Observer<int>*> observers;
+ for (int i = 0; i < N; i++) {
+ observers.push_back(makeObserver().release());
+ }
+ bs.dismiss();
+ for (int i = 0; i < N; i++) {
+ subject.observe(observers[i]);
+ }
+ bs.rehire();
+ for (int i = 0; i < N; i++) {
+ delete observers[i];
+ }
+ }
+}
+
+void notifySubscribers(uint iters, int N) {
+ for (uint iter = 0; iter < iters; iter++) {
+ BenchmarkSuspender bs;
+ Subject<int> subject;
+ std::vector<std::unique_ptr<Observer<int>>> observers;
+ std::vector<Subscription<int>> subscriptions;
+ subscriptions.reserve(N);
+ for (int i = 0; i < N; i++) {
+ observers.push_back(makeObserver());
+ }
+ for (int i = 0; i < N; i++) {
+ subscriptions.push_back(subject.subscribe(std::move(observers[i])));
+ }
+ bs.dismiss();
+ subject.onNext(42);
+ bs.rehire();
+ }
+}
+
+void notifyInlineObservers(uint iters, int N) {
+ for (uint iter = 0; iter < iters; iter++) {
+ BenchmarkSuspender bs;
+ Subject<int> subject;
+ std::vector<Observer<int>*> observers;
+ for (int i = 0; i < N; i++) {
+ observers.push_back(makeObserver().release());
+ }
+ for (int i = 0; i < N; i++) {
+ subject.observe(observers[i]);
+ }
+ bs.dismiss();
+ subject.onNext(42);
+ bs.rehire();
+ }
+}
+
+BENCHMARK_PARAM(subscribeAndUnsubscribe, 1);
+BENCHMARK_RELATIVE_PARAM(subscribe, 1);
+BENCHMARK_RELATIVE_PARAM(observe, 1);
+BENCHMARK_RELATIVE_PARAM(inlineObserve, 1);
+
+BENCHMARK_DRAW_LINE();
+
+BENCHMARK_PARAM(subscribeAndUnsubscribe, 1000);
+BENCHMARK_RELATIVE_PARAM(subscribe, 1000);
+BENCHMARK_RELATIVE_PARAM(observe, 1000);
+BENCHMARK_RELATIVE_PARAM(inlineObserve, 1000);
+
+BENCHMARK_DRAW_LINE();
+
+BENCHMARK_PARAM(notifySubscribers, 1);
+BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1);
+
+BENCHMARK_DRAW_LINE();
+
+BENCHMARK_PARAM(notifySubscribers, 1000);
+BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1000);
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ folly::runBenchmarks();
+ return 0;
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/wangle/rx/Observer.h>
+#include <folly/wangle/rx/Subject.h>
+#include <gtest/gtest.h>
+
+using namespace folly::wangle;
+
+static std::unique_ptr<Observer<int>> incrementer(int& counter) {
+ return Observer<int>::create([&] (int x) {
+ counter++;
+ });
+}
+
+TEST(RxTest, Observe) {
+ Subject<int> subject;
+ auto count = 0;
+ subject.observe(incrementer(count));
+ subject.onNext(1);
+ EXPECT_EQ(1, count);
+}
+
+TEST(RxTest, ObserveInline) {
+ Subject<int> subject;
+ auto count = 0;
+ auto o = incrementer(count).release();
+ subject.observe(o);
+ subject.onNext(1);
+ EXPECT_EQ(1, count);
+ delete o;
+}
+
+TEST(RxTest, Subscription) {
+ Subject<int> subject;
+ auto count = 0;
+ {
+ auto s = subject.subscribe(incrementer(count));
+ subject.onNext(1);
+ }
+ // The subscription has gone out of scope so no one should get this.
+ subject.onNext(2);
+ EXPECT_EQ(1, count);
+}
+
+TEST(RxTest, SubscriptionMove) {
+ Subject<int> subject;
+ auto count = 0;
+ auto s = subject.subscribe(incrementer(count));
+ auto s2 = subject.subscribe(incrementer(count));
+ s2 = std::move(s);
+ subject.onNext(1);
+ Subscription<int> s3(std::move(s2));
+ subject.onNext(2);
+ EXPECT_EQ(2, count);
+}
+
+TEST(RxTest, SubscriptionOutlivesSubject) {
+ Subscription<int> s;
+ {
+ Subject<int> subject;
+ s = subject.subscribe(Observer<int>::create([](int){}));
+ }
+ // Don't explode when s is destroyed
+}
+
+TEST(RxTest, SubscribeDuringCallback) {
+ // A subscriber who was subscribed in the course of a callback should get
+ // subsequent updates but not the current update.
+ Subject<int> subject;
+ int outerCount = 0, innerCount = 0;
+ Subscription<int> s1, s2;
+ s1 = subject.subscribe(Observer<int>::create([&] (int x) {
+ outerCount++;
+ s2 = subject.subscribe(incrementer(innerCount));
+ }));
+ subject.onNext(42);
+ subject.onNext(0xDEADBEEF);
+ EXPECT_EQ(2, outerCount);
+ EXPECT_EQ(1, innerCount);
+}
+
+TEST(RxTest, ObserveDuringCallback) {
+ Subject<int> subject;
+ int outerCount = 0, innerCount = 0;
+ subject.observe(Observer<int>::create([&] (int x) {
+ outerCount++;
+ subject.observe(incrementer(innerCount));
+ }));
+ subject.onNext(42);
+ subject.onNext(0xDEADBEEF);
+ EXPECT_EQ(2, outerCount);
+ EXPECT_EQ(1, innerCount);
+}
+
+TEST(RxTest, ObserveInlineDuringCallback) {
+ Subject<int> subject;
+ int outerCount = 0, innerCount = 0;
+ auto innerO = incrementer(innerCount).release();
+ auto outerO = Observer<int>::create([&] (int x) {
+ outerCount++;
+ subject.observe(innerO);
+ }).release();
+ subject.observe(outerO);
+ subject.onNext(42);
+ subject.onNext(0xDEADBEEF);
+ EXPECT_EQ(2, outerCount);
+ EXPECT_EQ(1, innerCount);
+ delete innerO;
+ delete outerO;
+}
+
+TEST(RxTest, UnsubscribeDuringCallback) {
+ // A subscriber who was unsubscribed in the course of a callback should get
+ // the current update but not subsequent ones
+ Subject<int> subject;
+ int count1 = 0, count2 = 0;
+ auto s1 = subject.subscribe(incrementer(count1));
+ auto s2 = subject.subscribe(Observer<int>::create([&] (int x) {
+ count2++;
+ s1.~Subscription();
+ }));
+ subject.onNext(1);
+ subject.onNext(2);
+ EXPECT_EQ(1, count1);
+ EXPECT_EQ(2, count2);
+}
+
+TEST(RxTest, SubscribeUnsubscribeDuringCallback) {
+ // A subscriber who was subscribed and unsubscribed in the course of a
+ // callback should not get any updates
+ Subject<int> subject;
+ int outerCount = 0, innerCount = 0;
+ auto s2 = subject.subscribe(Observer<int>::create([&] (int x) {
+ outerCount++;
+ auto s2 = subject.subscribe(incrementer(innerCount));
+ }));
+ subject.onNext(1);
+ subject.onNext(2);
+ EXPECT_EQ(2, outerCount);
+ EXPECT_EQ(0, innerCount);
+}
+
+// Move only type
+typedef std::unique_ptr<int> MO;
+static MO makeMO() { return folly::make_unique<int>(1); }
+template <typename T>
+static ObserverPtr<T> makeMOObserver() {
+ return Observer<T>::create([](const T& mo) {
+ EXPECT_EQ(1, *mo);
+ });
+}
+
+TEST(RxTest, MoveOnlyRvalue) {
+ Subject<MO> subject;
+ auto s1 = subject.subscribe(makeMOObserver<MO>());
+ auto s2 = subject.subscribe(makeMOObserver<MO>());
+ auto mo = makeMO();
+ // Can't bind lvalues to rvalue references
+ // subject.onNext(mo);
+ subject.onNext(std::move(mo));
+ subject.onNext(makeMO());
+}
+
+// Copy only type
+struct CO {
+ CO() = default;
+ CO(const CO&) = default;
+ CO(CO&&) = delete;
+};
+
+template <typename T>
+static ObserverPtr<T> makeCOObserver() {
+ return Observer<T>::create([](const T& mo) {});
+}
+
+TEST(RxTest, CopyOnly) {
+ Subject<CO> subject;
+ auto s1 = subject.subscribe(makeCOObserver<CO>());
+ CO co;
+ subject.onNext(co);
+}
--- /dev/null
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/ExceptionWrapper.h>
+#include <folly/Executor.h>
+
+namespace folly { namespace wangle {
+ typedef folly::exception_wrapper Error;
+ // The Executor is basically an rx Scheduler (by design). So just
+ // alias it.
+ typedef std::shared_ptr<folly::Executor> SchedulerPtr;
+
+ template <class T, size_t InlineObservers = 3> struct Observable;
+ template <class T> struct Observer;
+ template <class T> struct Subject;
+
+ template <class T> using ObservablePtr = std::shared_ptr<Observable<T>>;
+ template <class T> using ObserverPtr = std::shared_ptr<Observer<T>>;
+ template <class T> using SubjectPtr = std::shared_ptr<Subject<T>>;
+}}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+class ClientHelloExtStats {
+ public:
+ virtual ~ClientHelloExtStats() noexcept {}
+
+ // client hello
+ virtual void recordAbsentHostname() noexcept = 0;
+ virtual void recordMatch() noexcept = 0;
+ virtual void recordNotMatch() noexcept = 0;
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <openssl/dh.h>
+
+// The following was auto-generated by
+// openssl dhparam -C 2048
+DH *get_dh2048()
+ {
+ static unsigned char dh2048_p[]={
+ 0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95,
+ 0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79,
+ 0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02,
+ 0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC,
+ 0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8,
+ 0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6,
+ 0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8,
+ 0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF,
+ 0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E,
+ 0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09,
+ 0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D,
+ 0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B,
+ 0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A,
+ 0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8,
+ 0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E,
+ 0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4,
+ 0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9,
+ 0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E,
+ 0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6,
+ 0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD,
+ 0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB,
+ 0x8C,0xA1,0x76,0xDB,
+ };
+ static unsigned char dh2048_g[]={
+ 0x02,
+ };
+ DH *dh;
+
+ if ((dh=DH_new()) == nullptr) return(nullptr);
+ dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr);
+ dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr);
+ if ((dh->p == nullptr) || (dh->g == nullptr))
+ { DH_free(dh); return(nullptr); }
+ return(dh);
+ }
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/ssl/PasswordInFile.h>
+
+#include <folly/FileUtil.h>
+
+using namespace std;
+
+namespace folly {
+
+PasswordInFile::PasswordInFile(const string& file)
+ : fileName_(file) {
+ folly::readFile(file.c_str(), password_);
+ auto p = password_.find('\0');
+ if (p != std::string::npos) {
+ password_.erase(p);
+ }
+}
+
+PasswordInFile::~PasswordInFile() {
+ OPENSSL_cleanse((char *)password_.data(), password_.length());
+}
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/SSLContext.h> // PasswordCollector
+
+namespace folly {
+
+class PasswordInFile: public folly::PasswordCollector {
+ public:
+ explicit PasswordInFile(const std::string& file);
+ ~PasswordInFile();
+
+ void getPassword(std::string& password, int size) override {
+ password = password_;
+ }
+
+ const char* getPasswordStr() const {
+ return password_.c_str();
+ }
+
+ std::string describe() const override {
+ return fileName_;
+ }
+
+ protected:
+ std::string fileName_;
+ std::string password_;
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <chrono>
+#include <cstdint>
+
+namespace folly {
+
+struct SSLCacheOptions {
+ std::chrono::seconds sslCacheTimeout;
+ uint64_t maxSSLCacheSize;
+ uint64_t sslCacheFlushSize;
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/AsyncSSLSocket.h>
+
+namespace folly {
+
+class SSLSessionCacheManager;
+
+/**
+ * Interface to be implemented by providers of external session caches
+ */
+class SSLCacheProvider {
+public:
+ /**
+ * Context saved during an external cache request that is used to
+ * resume the waiting client.
+ */
+ typedef struct {
+ std::string sessionId;
+ SSL_SESSION* session;
+ SSLSessionCacheManager* manager;
+ AsyncSSLSocket* sslSocket;
+ std::unique_ptr<
+ folly::DelayedDestruction::DestructorGuard> guard;
+ } CacheContext;
+
+ virtual ~SSLCacheProvider() {}
+
+ /**
+ * Store a session in the external cache.
+ * @param sessionId Identifier that can be used later to fetch the
+ * session with getAsync()
+ * @param value Serialized session to store
+ * @param expiration Relative expiration time: seconds from now
+ * @return true if the storing of the session is initiated successfully
+ * (though not necessarily completed; the completion may
+ * happen either before or after this method returns), or
+ * false if the storing cannot be initiated due to an error.
+ */
+ virtual bool setAsync(const std::string& sessionId,
+ const std::string& value,
+ std::chrono::seconds expiration) = 0;
+
+ /**
+ * Retrieve a session from the external cache. When done, call
+ * the cache manager's onGetSuccess() or onGetFailure() callback.
+ * @param sessionId Session ID to fetch
+ * @param context Data to pass back to the SSLSessionCacheManager
+ * in the completion callback
+ * @return true if the lookup of the session is initiated successfully
+ * (though not necessarily completed; the completion may
+ * happen either before or after this method returns), or
+ * false if the lookup cannot be initiated due to an error.
+ */
+ virtual bool getAsync(const std::string& sessionId,
+ CacheContext* context) = 0;
+
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <string>
+#include <folly/io/async/SSLContext.h>
+#include <vector>
+
+/**
+ * SSLContextConfig helps to describe the configs/options for
+ * a SSL_CTX. For example:
+ *
+ * 1. Filename of X509, private key and its password.
+ * 2. ciphers list
+ * 3. NPN list
+ * 4. Is session cache enabled?
+ * 5. Is it the default X509 in SNI operation?
+ * 6. .... and a few more
+ */
+namespace folly {
+
+struct SSLContextConfig {
+ SSLContextConfig() {}
+ ~SSLContextConfig() {}
+
+ struct CertificateInfo {
+ std::string certPath;
+ std::string keyPath;
+ std::string passwordPath;
+ };
+
+ /**
+ * Helpers to set/add a certificate
+ */
+ void setCertificate(const std::string& certPath,
+ const std::string& keyPath,
+ const std::string& passwordPath) {
+ certificates.clear();
+ addCertificate(certPath, keyPath, passwordPath);
+ }
+
+ void addCertificate(const std::string& certPath,
+ const std::string& keyPath,
+ const std::string& passwordPath) {
+ certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath});
+ }
+
+ /**
+ * Set the optional list of protocols to advertise via TLS
+ * Next Protocol Negotiation. An empty list means NPN is not enabled.
+ */
+ void setNextProtocols(const std::list<std::string>& inNextProtocols) {
+ nextProtocols.clear();
+ nextProtocols.push_back({1, inNextProtocols});
+ }
+
+ typedef std::function<bool(char const* server_name)> SNINoMatchFn;
+
+ std::vector<CertificateInfo> certificates;
+ folly::SSLContext::SSLVersion sslVersion{
+ folly::SSLContext::TLSv1};
+ bool sessionCacheEnabled{true};
+ bool sessionTicketEnabled{true};
+ bool clientHelloParsingEnabled{false};
+ std::string sslCiphers{
+ "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:"
+ "ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:"
+ "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:"
+ "AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:"
+ "ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:"
+ "ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"};
+ std::string eccCurveName;
+ // Ciphers to negotiate if TLS version >= 1.1
+ std::string tls11Ciphers{""};
+ // Weighted lists of NPN strings to advertise
+ std::list<folly::SSLContext::NextProtocolsItem>
+ nextProtocols;
+ bool isLocalPrivateKey{true};
+ // Should this SSLContextConfig be the default for SNI purposes
+ bool isDefault{false};
+ // Callback function to invoke when there are no matching certificates
+ // (will only be invoked once)
+ SNINoMatchFn sniNoMatchFn;
+ // File containing trusted CA's to validate client certificates
+ std::string clientCAFile;
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/ssl/SSLContextManager.h>
+
+#include <folly/wangle/ssl/ClientHelloExtStats.h>
+#include <folly/wangle/ssl/DHParam.h>
+#include <folly/wangle/ssl/PasswordInFile.h>
+#include <folly/wangle/ssl/SSLCacheOptions.h>
+#include <folly/wangle/ssl/SSLSessionCacheManager.h>
+#include <folly/wangle/ssl/SSLUtil.h>
+#include <folly/wangle/ssl/TLSTicketKeyManager.h>
+#include <folly/wangle/ssl/TLSTicketKeySeeds.h>
+
+#include <folly/Conv.h>
+#include <folly/ScopeGuard.h>
+#include <folly/String.h>
+#include <functional>
+#include <openssl/asn1.h>
+#include <openssl/ssl.h>
+#include <string>
+#include <folly/io/async/EventBase.h>
+
+#define OPENSSL_MISSING_FEATURE(name) \
+do { \
+ throw std::runtime_error("missing " #name " support in openssl"); \
+} while(0)
+
+
+using std::string;
+using std::shared_ptr;
+
+/**
+ * SSLContextManager helps to create and manage all SSL_CTX,
+ * SSLSessionCacheManager and TLSTicketManager for a listening
+ * VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)).
+ *
+ * Other responsibilities:
+ * 1. It also handles the SSL_CTX selection after getting the tlsext_hostname
+ * in the client hello message.
+ *
+ * Usage:
+ * 1. Each listening VIP:PORT serving SSL should have one SSLContextManager.
+ * It maps to Acceptor in the wangle vocabulary.
+ *
+ * 2. Create a SSLContextConfig object (e.g. by parsing the JSON config).
+ *
+ * 3. Call SSLContextManager::addSSLContextConfig() which will
+ * then create and configure the SSL_CTX
+ *
+ * Note: Each Acceptor, with SSL support, should have one SSLContextManager to
+ * manage all SSL_CTX for the VIP:PORT.
+ */
+
+namespace folly {
+
+namespace {
+
+X509* getX509(SSL_CTX* ctx) {
+ SSL* ssl = SSL_new(ctx);
+ SSL_set_connect_state(ssl);
+ X509* x509 = SSL_get_certificate(ssl);
+ CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
+ SSL_free(ssl);
+ return x509;
+}
+
+void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) {
+#if OPENSSL_VERSION_NUMBER >= 0x0090800fL
+#ifndef OPENSSL_NO_ECDH
+ EC_KEY* ecdh = nullptr;
+ int nid;
+
+ /*
+ * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
+ * from RFC 4492 section 5.1.1, or explicitly described curves over
+ * binary fields. OpenSSL only supports the "named curves", which provide
+ * maximum interoperability.
+ */
+
+ nid = OBJ_sn2nid(curveName.c_str());
+ if (nid == 0) {
+ LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
+ return;
+ }
+ ecdh = EC_KEY_new_by_curve_name(nid);
+ if (ecdh == nullptr) {
+ LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
+ return;
+ }
+
+ SSL_CTX_set_tmp_ecdh(ctx, ecdh);
+ EC_KEY_free(ecdh);
+#endif
+#endif
+}
+
+// Helper to create TLSTicketKeyManger and aware of the needed openssl
+// version/feature.
+std::unique_ptr<TLSTicketKeyManager> createTicketManagerHelper(
+ std::shared_ptr<folly::SSLContext> ctx,
+ const TLSTicketKeySeeds* ticketSeeds,
+ const SSLContextConfig& ctxConfig,
+ SSLStats* stats) {
+
+ std::unique_ptr<TLSTicketKeyManager> ticketManager;
+#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+ if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
+ ticketManager = folly::make_unique<TLSTicketKeyManager>(ctx.get(), stats);
+ ticketManager->setTLSTicketKeySeeds(
+ ticketSeeds->oldSeeds,
+ ticketSeeds->currentSeeds,
+ ticketSeeds->newSeeds);
+ } else {
+ ctx->setOptions(SSL_OP_NO_TICKET);
+ }
+#else
+ if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
+ OPENSSL_MISSING_FEATURE(TLSTicket);
+ }
+#endif
+ return ticketManager;
+}
+
+std::string flattenList(const std::list<std::string>& list) {
+ std::string s;
+ bool first = true;
+ for (auto& item : list) {
+ if (first) {
+ first = false;
+ } else {
+ s.append(", ");
+ }
+ s.append(item);
+ }
+ return s;
+}
+
+}
+
+SSLContextManager::~SSLContextManager() {}
+
+SSLContextManager::SSLContextManager(
+ EventBase* eventBase,
+ const std::string& vipName,
+ bool strict,
+ SSLStats* stats) :
+ stats_(stats),
+ eventBase_(eventBase),
+ strict_(strict) {
+}
+
+void SSLContextManager::addSSLContextConfig(
+ const SSLContextConfig& ctxConfig,
+ const SSLCacheOptions& cacheOptions,
+ const TLSTicketKeySeeds* ticketSeeds,
+ const folly::SocketAddress& vipAddress,
+ const std::shared_ptr<SSLCacheProvider>& externalCache) {
+
+ unsigned numCerts = 0;
+ std::string commonName;
+ std::string lastCertPath;
+ std::unique_ptr<std::list<std::string>> subjectAltName;
+ auto sslCtx = std::make_shared<SSLContext>(ctxConfig.sslVersion);
+ for (const auto& cert : ctxConfig.certificates) {
+ try {
+ sslCtx->loadCertificate(cert.certPath.c_str());
+ } catch (const std::exception& ex) {
+ // The exception isn't very useful without the certificate path name,
+ // so throw a new exception that includes the path to the certificate.
+ string msg = folly::to<string>("error loading SSL certificate ",
+ cert.certPath, ": ",
+ folly::exceptionStr(ex));
+ LOG(ERROR) << msg;
+ throw std::runtime_error(msg);
+ }
+
+ // Verify that the Common Name and (if present) Subject Alternative Names
+ // are the same for all the certs specified for the SSL context.
+ numCerts++;
+ X509* x509 = getX509(sslCtx->getSSLCtx());
+ auto guard = folly::makeGuard([x509] { X509_free(x509); });
+ auto cn = SSLUtil::getCommonName(x509);
+ if (!cn) {
+ throw std::runtime_error(folly::to<string>("Cannot get CN for X509 ",
+ cert.certPath));
+ }
+ auto altName = SSLUtil::getSubjectAltName(x509);
+ VLOG(2) << "cert " << cert.certPath << " CN: " << *cn;
+ if (altName) {
+ altName->sort();
+ VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName);
+ } else {
+ VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}";
+ }
+ if (numCerts == 1) {
+ commonName = *cn;
+ subjectAltName = std::move(altName);
+ } else {
+ if (commonName != *cn) {
+ throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
+ " does not have same CN as ",
+ lastCertPath));
+ }
+ if (altName == nullptr) {
+ if (subjectAltName != nullptr) {
+ throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
+ " does not have same SAN as ",
+ lastCertPath));
+ }
+ } else {
+ if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) {
+ throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
+ " does not have same SAN as ",
+ lastCertPath));
+ }
+ }
+ }
+ lastCertPath = cert.certPath;
+
+ // TODO t4438250 - Add ECDSA support to the crypto_ssl offload server
+ // so we can avoid storing the ECDSA private key in the
+ // address space of the Internet-facing process. For
+ // now, if cert name includes "-EC" to denote elliptic
+ // curve, we load its private key even if the server as
+ // a whole has been configured for async crypto.
+ if (ctxConfig.isLocalPrivateKey ||
+ (cert.certPath.find("-EC") != std::string::npos)) {
+ // The private key lives in the same process
+
+ // This needs to be called before loadPrivateKey().
+ if (!cert.passwordPath.empty()) {
+ auto sslPassword = std::make_shared<PasswordInFile>(cert.passwordPath);
+ sslCtx->passwordCollector(sslPassword);
+ }
+
+ try {
+ sslCtx->loadPrivateKey(cert.keyPath.c_str());
+ } catch (const std::exception& ex) {
+ // Throw an error that includes the key path, so the user can tell
+ // which key had a problem.
+ string msg = folly::to<string>("error loading private SSL key ",
+ cert.keyPath, ": ",
+ folly::exceptionStr(ex));
+ LOG(ERROR) << msg;
+ throw std::runtime_error(msg);
+ }
+ }
+ }
+ if (!ctxConfig.isLocalPrivateKey) {
+ enableAsyncCrypto(sslCtx);
+ }
+
+ // Let the server pick the highest performing cipher from among the client's
+ // choices.
+ //
+ // Let's use a unique private key for all DH key exchanges.
+ //
+ // Because some old implementations choke on empty fragments, most SSL
+ // applications disable them (it's part of SSL_OP_ALL). This
+ // will improve performance and decrease write buffer fragmentation.
+ sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE |
+ SSL_OP_SINGLE_DH_USE |
+ SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
+
+ // Configure SSL ciphers list
+ if (!ctxConfig.tls11Ciphers.empty()) {
+ // FIXME: create a dummy SSL_CTX for cipher testing purpose? It can
+ // remove the ordering dependency
+
+ // Test to see if the specified TLS1.1 ciphers are valid. Note that
+ // these will be overwritten by the ciphers() call below.
+ sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers);
+ }
+
+ // Important that we do this *after* checking the TLS1.1 ciphers above,
+ // since we test their validity by actually setting them.
+ sslCtx->ciphers(ctxConfig.sslCiphers);
+
+ // Use a fix DH param
+ DH* dh = get_dh2048();
+ SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh);
+ DH_free(dh);
+
+ const string& curve = ctxConfig.eccCurveName;
+ if (!curve.empty()) {
+ set_key_from_curve(sslCtx->getSSLCtx(), curve);
+ }
+
+ if (!ctxConfig.clientCAFile.empty()) {
+ try {
+ sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT);
+ sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str());
+ sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str());
+ } catch (const std::exception& ex) {
+ string msg = folly::to<string>("error loading client CA",
+ ctxConfig.clientCAFile, ": ",
+ folly::exceptionStr(ex));
+ LOG(ERROR) << msg;
+ throw std::runtime_error(msg);
+ }
+ }
+
+ // - start - SSL session cache config
+ // the internal cache never does what we want (per-thread-per-vip).
+ // Disable it. SSLSessionCacheManager will set it appropriately.
+ SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF);
+ SSL_CTX_set_timeout(sslCtx->getSSLCtx(),
+ cacheOptions.sslCacheTimeout.count());
+ std::unique_ptr<SSLSessionCacheManager> sessionCacheManager;
+ if (ctxConfig.sessionCacheEnabled &&
+ cacheOptions.maxSSLCacheSize > 0 &&
+ cacheOptions.sslCacheFlushSize > 0) {
+ sessionCacheManager =
+ folly::make_unique<SSLSessionCacheManager>(
+ cacheOptions.maxSSLCacheSize,
+ cacheOptions.sslCacheFlushSize,
+ sslCtx.get(),
+ vipAddress,
+ commonName,
+ eventBase_,
+ stats_,
+ externalCache);
+ }
+ // - end - SSL session cache config
+
+ std::unique_ptr<TLSTicketKeyManager> ticketManager =
+ createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_);
+
+ // finalize sslCtx setup by the individual features supported by openssl
+ ctxSetupByOpensslFeature(sslCtx, ctxConfig);
+
+ try {
+ insert(sslCtx,
+ std::move(sessionCacheManager),
+ std::move(ticketManager),
+ ctxConfig.isDefault);
+ } catch (const std::exception& ex) {
+ string msg = folly::to<string>("Error adding certificate : ",
+ folly::exceptionStr(ex));
+ LOG(ERROR) << msg;
+ throw std::runtime_error(msg);
+ }
+
+}
+
+#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
+SSLContext::ServerNameCallbackResult
+SSLContextManager::serverNameCallback(SSL* ssl) {
+ shared_ptr<SSLContext> ctx;
+
+ const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+ if (!sn) {
+ VLOG(6) << "Server Name (tlsext_hostname) is missing";
+ if (clientHelloTLSExtStats_) {
+ clientHelloTLSExtStats_->recordAbsentHostname();
+ }
+ return SSLContext::SERVER_NAME_NOT_FOUND;
+ }
+ size_t snLen = strlen(sn);
+ VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' ";
+
+ // FIXME: This code breaks the abstraction. Suggestion?
+ AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
+ CHECK(sslSocket);
+
+ DNString dnstr(sn, snLen);
+
+ uint32_t count = 0;
+ do {
+ // Try exact match first
+ ctx = getSSLCtx(dnstr);
+ if (ctx) {
+ sslSocket->switchServerSSLContext(ctx);
+ if (clientHelloTLSExtStats_) {
+ clientHelloTLSExtStats_->recordMatch();
+ }
+ return SSLContext::SERVER_NAME_FOUND;
+ }
+
+ ctx = getSSLCtxBySuffix(dnstr);
+ if (ctx) {
+ sslSocket->switchServerSSLContext(ctx);
+ if (clientHelloTLSExtStats_) {
+ clientHelloTLSExtStats_->recordMatch();
+ }
+ return SSLContext::SERVER_NAME_FOUND;
+ }
+
+ // Give the noMatchFn one chance to add the correct cert
+ }
+ while (count++ == 0 && noMatchFn_ && noMatchFn_(sn));
+
+ VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn);
+
+ if (clientHelloTLSExtStats_) {
+ clientHelloTLSExtStats_->recordNotMatch();
+ }
+ return SSLContext::SERVER_NAME_NOT_FOUND;
+}
+#endif
+
+// Consolidate all SSL_CTX setup which depends on openssl version/feature
+void
+SSLContextManager::ctxSetupByOpensslFeature(
+ shared_ptr<folly::SSLContext> sslCtx,
+ const SSLContextConfig& ctxConfig) {
+ // Disable compression - profiling shows this to be very expensive in
+ // terms of CPU and memory consumption.
+ //
+#ifdef SSL_OP_NO_COMPRESSION
+ sslCtx->setOptions(SSL_OP_NO_COMPRESSION);
+#endif
+
+ // Enable early release of SSL buffers to reduce the memory footprint
+#ifdef SSL_MODE_RELEASE_BUFFERS
+ sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS;
+#endif
+#ifdef SSL_MODE_EARLY_RELEASE_BBIO
+ sslCtx->getSSLCtx()->mode |= SSL_MODE_EARLY_RELEASE_BBIO;
+#endif
+
+ // This number should (probably) correspond to HTTPSession::kMaxReadSize
+ // For now, this number must also be large enough to accommodate our
+ // largest certificate, because some older clients (IE6/7) require the
+ // cert to be in a single fragment.
+#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT
+ SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000);
+#endif
+
+ // Specify cipher(s) to be used for TLS1.1 client
+ if (!ctxConfig.tls11Ciphers.empty()) {
+#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
+ // Specified TLS1.1 ciphers are valid
+ sslCtx->addClientHelloCallback(
+ std::bind(
+ &SSLContext::switchCiphersIfTLS11,
+ sslCtx.get(),
+ std::placeholders::_1,
+ ctxConfig.tls11Ciphers
+ )
+ );
+#else
+ OPENSSL_MISSING_FEATURE(SNI);
+#endif
+ }
+
+ // NPN (Next Protocol Negotiation)
+ if (!ctxConfig.nextProtocols.empty()) {
+#ifdef OPENSSL_NPN_NEGOTIATED
+ sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols);
+#else
+ OPENSSL_MISSING_FEATURE(NPN);
+#endif
+ }
+
+ // SNI
+#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
+ noMatchFn_ = ctxConfig.sniNoMatchFn;
+ if (ctxConfig.isDefault) {
+ if (defaultCtx_) {
+ throw std::runtime_error(">1 X509 is set as default");
+ }
+
+ defaultCtx_ = sslCtx;
+ defaultCtx_->setServerNameCallback(
+ std::bind(&SSLContextManager::serverNameCallback, this,
+ std::placeholders::_1));
+ }
+#else
+ if (ctxs_.size() > 1) {
+ OPENSSL_MISSING_FEATURE(SNI);
+ }
+#endif
+}
+
+void
+SSLContextManager::insert(shared_ptr<SSLContext> sslCtx,
+ std::unique_ptr<SSLSessionCacheManager> smanager,
+ std::unique_ptr<TLSTicketKeyManager> tmanager,
+ bool defaultFallback) {
+ X509* x509 = getX509(sslCtx->getSSLCtx());
+ auto guard = folly::makeGuard([x509] { X509_free(x509); });
+ auto cn = SSLUtil::getCommonName(x509);
+ if (!cn) {
+ throw std::runtime_error("Cannot get CN");
+ }
+
+ /**
+ * Some notes from RFC 2818. Only for future quick references in case of bugs
+ *
+ * RFC 2818 section 3.1:
+ * "......
+ * If a subjectAltName extension of type dNSName is present, that MUST
+ * be used as the identity. Otherwise, the (most specific) Common Name
+ * field in the Subject field of the certificate MUST be used. Although
+ * the use of the Common Name is existing practice, it is deprecated and
+ * Certification Authorities are encouraged to use the dNSName instead.
+ * ......
+ * In some cases, the URI is specified as an IP address rather than a
+ * hostname. In this case, the iPAddress subjectAltName must be present
+ * in the certificate and must exactly match the IP in the URI.
+ * ......"
+ */
+
+ // Not sure if we ever get this kind of X509...
+ // If we do, assume '*' is always in the CN and ignore all subject alternative
+ // names.
+ if (cn->length() == 1 && (*cn)[0] == '*') {
+ if (!defaultFallback) {
+ throw std::runtime_error("STAR X509 is not the default");
+ }
+ ctxs_.emplace_back(sslCtx);
+ sessionCacheManagers_.emplace_back(std::move(smanager));
+ ticketManagers_.emplace_back(std::move(tmanager));
+ return;
+ }
+
+ // Insert by CN
+ insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx);
+
+ // Insert by subject alternative name(s)
+ auto altNames = SSLUtil::getSubjectAltName(x509);
+ if (altNames) {
+ for (auto& name : *altNames) {
+ insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx);
+ }
+ }
+
+ ctxs_.emplace_back(sslCtx);
+ sessionCacheManagers_.emplace_back(std::move(smanager));
+ ticketManagers_.emplace_back(std::move(tmanager));
+}
+
+void
+SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len,
+ shared_ptr<SSLContext> sslCtx) {
+ try {
+ insertSSLCtxByDomainNameImpl(dn, len, sslCtx);
+ } catch (const std::runtime_error& ex) {
+ if (strict_) {
+ throw ex;
+ } else {
+ LOG(ERROR) << ex.what() << " DN=" << dn;
+ }
+ }
+}
+void
+SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len,
+ shared_ptr<SSLContext> sslCtx)
+{
+ VLOG(4) <<
+ folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for "
+ "SNI search", dn);
+
+ // Only support wildcard domains which are prefixed exactly by "*." .
+ // "*" appearing at other locations is not accepted.
+
+ if (len > 2 && dn[0] == '*') {
+ if (dn[1] == '.') {
+ // skip the first '*'
+ dn++;
+ len--;
+ } else {
+ throw std::runtime_error(
+ "Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" "
+ "(only allow character \".\" after \"*\"");
+ }
+ }
+
+ if (len == 1 && *dn == '.') {
+ throw std::runtime_error("X509 has only '.' in the CN or subject alternative name "
+ "(after removing any preceding '*')");
+ }
+
+ if (strchr(dn, '*')) {
+ throw std::runtime_error("X509 has '*' in the the CN or subject alternative name "
+ "(after removing any preceding '*')");
+ }
+
+ DNString dnstr(dn, len);
+ const auto v = dnMap_.find(dnstr);
+ if (v == dnMap_.end()) {
+ dnMap_.emplace(dnstr, sslCtx);
+ } else if (v->second == sslCtx) {
+ VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509."
+ " Ignore the later name.";
+ } else {
+ throw std::runtime_error("Duplicate CN or subject alternative name found: \"" +
+ std::string(dnstr.c_str()) + "\"");
+ }
+}
+
+shared_ptr<SSLContext>
+SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const
+{
+ size_t dot;
+
+ if ((dot = dnstr.find_first_of(".")) != DNString::npos) {
+ DNString suffixDNStr(dnstr, dot);
+ const auto v = dnMap_.find(suffixDNStr);
+ if (v != dnMap_.end()) {
+ VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"",
+ dnstr.c_str(), suffixDNStr.c_str());
+ return v->second;
+ }
+ }
+
+ VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match",
+ dnstr.c_str());
+ return shared_ptr<SSLContext>();
+}
+
+shared_ptr<SSLContext>
+SSLContextManager::getSSLCtx(const DNString& dnstr) const
+{
+ const auto v = dnMap_.find(dnstr);
+ if (v == dnMap_.end()) {
+ VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match",
+ dnstr.c_str());
+ return shared_ptr<SSLContext>();
+ } else {
+ VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str());
+ return v->second;
+ }
+}
+
+shared_ptr<SSLContext>
+SSLContextManager::getDefaultSSLCtx() const {
+ return defaultCtx_;
+}
+
+void
+SSLContextManager::reloadTLSTicketKeys(
+ const std::vector<std::string>& oldSeeds,
+ const std::vector<std::string>& currentSeeds,
+ const std::vector<std::string>& newSeeds) {
+#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+ for (auto& tmgr: ticketManagers_) {
+ tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds);
+ }
+#endif
+}
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/SSLContext.h>
+
+#include <glog/logging.h>
+#include <list>
+#include <memory>
+#include <folly/wangle/ssl/SSLContextConfig.h>
+#include <folly/wangle/ssl/SSLSessionCacheManager.h>
+#include <folly/wangle/ssl/TLSTicketKeySeeds.h>
+#include <folly/wangle/acceptor/DomainNameMisc.h>
+#include <vector>
+
+namespace folly {
+
+class SocketAddress;
+class SSLContext;
+class ClientHelloExtStats;
+class SSLCacheOptions;
+class SSLStats;
+class TLSTicketKeyManager;
+class TLSTicketKeySeeds;
+
+class SSLContextManager {
+ public:
+
+ explicit SSLContextManager(EventBase* eventBase,
+ const std::string& vipName, bool strict,
+ SSLStats* stats);
+ virtual ~SSLContextManager();
+
+ /**
+ * Add a new X509 to SSLContextManager. The details of a X509
+ * is passed as a SSLContextConfig object.
+ *
+ * @param ctxConfig Details of a X509, its private key, password, etc.
+ * @param cacheOptions Options for how to do session caching.
+ * @param ticketSeeds If non-null, the initial ticket key seeds to use.
+ * @param vipAddress Which VIP are the X509(s) used for? It is only for
+ * for user friendly log message
+ * @param externalCache Optional external provider for the session cache;
+ * may be null
+ */
+ void addSSLContextConfig(
+ const SSLContextConfig& ctxConfig,
+ const SSLCacheOptions& cacheOptions,
+ const TLSTicketKeySeeds* ticketSeeds,
+ const folly::SocketAddress& vipAddress,
+ const std::shared_ptr<SSLCacheProvider> &externalCache);
+
+ /**
+ * Get the default SSL_CTX for a VIP
+ */
+ std::shared_ptr<SSLContext>
+ getDefaultSSLCtx() const;
+
+ /**
+ * Search by the _one_ level up subdomain
+ */
+ std::shared_ptr<SSLContext>
+ getSSLCtxBySuffix(const DNString& dnstr) const;
+
+ /**
+ * Search by the full-string domain name
+ */
+ std::shared_ptr<SSLContext>
+ getSSLCtx(const DNString& dnstr) const;
+
+ /**
+ * Insert a SSLContext by domain name.
+ */
+ void insertSSLCtxByDomainName(
+ const char* dn,
+ size_t len,
+ std::shared_ptr<SSLContext> sslCtx);
+
+ void insertSSLCtxByDomainNameImpl(
+ const char* dn,
+ size_t len,
+ std::shared_ptr<SSLContext> sslCtx);
+
+ void reloadTLSTicketKeys(const std::vector<std::string>& oldSeeds,
+ const std::vector<std::string>& currentSeeds,
+ const std::vector<std::string>& newSeeds);
+
+ /**
+ * SSLContextManager only collects SNI stats now
+ */
+
+ void setClientHelloExtStats(ClientHelloExtStats* stats) {
+ clientHelloTLSExtStats_ = stats;
+ }
+
+ protected:
+ virtual void enableAsyncCrypto(
+ const std::shared_ptr<SSLContext>& sslCtx) {
+ LOG(FATAL) << "Unsupported in base SSLContextManager";
+ }
+ SSLStats* stats_{nullptr};
+
+ private:
+ SSLContextManager(const SSLContextManager&) = delete;
+
+ void ctxSetupByOpensslFeature(
+ std::shared_ptr<SSLContext> sslCtx,
+ const SSLContextConfig& ctxConfig);
+
+ /**
+ * Callback function from openssl to find the right X509 to
+ * use during SSL handshake
+ */
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \
+ !defined(OPENSSL_NO_TLSEXT) && \
+ defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB)
+# define PROXYGEN_HAVE_SERVERNAMECALLBACK
+ SSLContext::ServerNameCallbackResult
+ serverNameCallback(SSL* ssl);
+#endif
+
+ /**
+ * The following functions help to maintain the data structure for
+ * domain name matching in SNI. Some notes:
+ *
+ * 1. It is a best match.
+ *
+ * 2. It allows wildcard CN and wildcard subject alternative name in a X509.
+ * The wildcard name must be _prefixed_ by '*.'. It errors out whenever
+ * it sees '*' in any other locations.
+ *
+ * 3. It uses one std::unordered_map<DomainName, SSL_CTX> object to
+ * do this. For wildcard name like "*.facebook.com", ".facebook.com"
+ * is used as the key.
+ *
+ * 4. After getting tlsext_hostname from the client hello message, it
+ * will do a full string search first and then try one level up to
+ * match any wildcard name (if any) in the X509.
+ * [Note, browser also only looks one level up when matching the requesting
+ * domain name with the wildcard name in the server X509].
+ */
+
+ void insert(
+ std::shared_ptr<SSLContext> sslCtx,
+ std::unique_ptr<SSLSessionCacheManager> cmanager,
+ std::unique_ptr<TLSTicketKeyManager> tManager,
+ bool defaultFallback);
+
+ /**
+ * Container to own the SSLContext, SSLSessionCacheManager and
+ * TLSTicketKeyManager.
+ */
+ std::vector<std::shared_ptr<SSLContext>> ctxs_;
+ std::vector<std::unique_ptr<SSLSessionCacheManager>>
+ sessionCacheManagers_;
+ std::vector<std::unique_ptr<TLSTicketKeyManager>> ticketManagers_;
+
+ std::shared_ptr<SSLContext> defaultCtx_;
+
+ /**
+ * Container to store the (DomainName -> SSL_CTX) mapping
+ */
+ std::unordered_map<
+ DNString,
+ std::shared_ptr<SSLContext>,
+ DNStringHash> dnMap_;
+
+ EventBase* eventBase_;
+ ClientHelloExtStats* clientHelloTLSExtStats_{nullptr};
+ SSLContextConfig::SNINoMatchFn noMatchFn_;
+ bool strict_{true};
+};
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/ssl/SSLSessionCacheManager.h>
+
+#include <folly/wangle/ssl/SSLCacheProvider.h>
+#include <folly/wangle/ssl/SSLStats.h>
+#include <folly/wangle/ssl/SSLUtil.h>
+
+#include <folly/io/async/EventBase.h>
+
+#ifndef NO_LIB_GFLAGS
+#include <gflags/gflags.h>
+#endif
+
+using std::string;
+using std::shared_ptr;
+
+namespace {
+
+const uint32_t NUM_CACHE_BUCKETS = 16;
+
+// We use the default ID generator which fills the maximum ID length
+// for the protocol. 16 bytes for SSLv2 or 32 for SSLv3+
+const int MIN_SESSION_ID_LENGTH = 16;
+
+}
+
+#ifndef NO_LIB_GFLAGS
+DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache");
+#else
+const bool FLAGS_dcache_unit_test = false;
+#endif
+
+namespace folly {
+
+
+int SSLSessionCacheManager::sExDataIndex_ = -1;
+shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::sCache_;
+std::mutex SSLSessionCacheManager::sCacheLock_;
+
+LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize,
+ uint32_t cacheCullSize)
+ : sessionCache(maxCacheSize, cacheCullSize) {
+ sessionCache.setPruneHook(std::bind(
+ &LocalSSLSessionCache::pruneSessionCallback,
+ this, std::placeholders::_1,
+ std::placeholders::_2));
+}
+
+void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId,
+ SSL_SESSION* session) {
+ VLOG(4) << "Free SSL session from local cache; id="
+ << SSLUtil::hexlify(sessionId);
+ SSL_SESSION_free(session);
+ ++removedSessions_;
+}
+
+
+// SSLSessionCacheManager implementation
+
+SSLSessionCacheManager::SSLSessionCacheManager(
+ uint32_t maxCacheSize,
+ uint32_t cacheCullSize,
+ SSLContext* ctx,
+ const folly::SocketAddress& sockaddr,
+ const string& context,
+ EventBase* eventBase,
+ SSLStats* stats,
+ const std::shared_ptr<SSLCacheProvider>& externalCache):
+ ctx_(ctx),
+ stats_(stats),
+ externalCache_(externalCache) {
+
+ SSL_CTX* sslCtx = ctx->getSSLCtx();
+
+ SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
+
+ SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this);
+ SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback);
+ SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback);
+ SSL_CTX_sess_set_remove_cb(sslCtx,
+ SSLSessionCacheManager::removeSessionCallback);
+ if (!FLAGS_dcache_unit_test && !context.empty()) {
+ // Use the passed in context
+ SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(),
+ std::min((int)context.length(),
+ SSL_MAX_SSL_SESSION_ID_LENGTH));
+ }
+
+ SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL
+ | SSL_SESS_CACHE_SERVER);
+
+ localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize,
+ cacheCullSize);
+
+ VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context;
+}
+
+SSLSessionCacheManager::~SSLSessionCacheManager() {
+}
+
+void SSLSessionCacheManager::shutdown() {
+ std::lock_guard<std::mutex> g(sCacheLock_);
+ sCache_.reset();
+}
+
+shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::getLocalCache(
+ uint32_t maxCacheSize,
+ uint32_t cacheCullSize) {
+
+ std::lock_guard<std::mutex> g(sCacheLock_);
+ if (!sCache_) {
+ sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS,
+ maxCacheSize,
+ cacheCullSize));
+ }
+ return sCache_;
+}
+
+int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
+ SSLSessionCacheManager* manager = nullptr;
+ SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
+ manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+ if (manager == nullptr) {
+ LOG(FATAL) << "Null SSLSessionCacheManager in callback";
+ return -1;
+ }
+ return manager->newSession(ssl, session);
+}
+
+
+int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) {
+ string sessionId((char*)session->session_id, session->session_id_length);
+ VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId);
+
+ if (stats_) {
+ stats_->recordSSLSession(true /* new session */, false, false);
+ }
+
+ localCache_->storeSession(sessionId, session, stats_);
+
+ if (externalCache_) {
+ VLOG(4) << "New SSL session: send session to external cache; id=" <<
+ SSLUtil::hexlify(sessionId);
+ storeCacheRecord(sessionId, session);
+ }
+
+ return 1;
+}
+
+void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx,
+ SSL_SESSION* session) {
+ SSLSessionCacheManager* manager = nullptr;
+ manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+ if (manager == nullptr) {
+ LOG(FATAL) << "Null SSLSessionCacheManager in callback";
+ return;
+ }
+ return manager->removeSession(ctx, session);
+}
+
+void SSLSessionCacheManager::removeSession(SSL_CTX* ctx,
+ SSL_SESSION* session) {
+ string sessionId((char*)session->session_id, session->session_id_length);
+
+ // This hook is only called from SSL when the internal session cache needs to
+ // flush sessions. Since we run with the internal cache disabled, this should
+ // never be called
+ VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId);
+
+ localCache_->removeSession(sessionId);
+
+ if (stats_) {
+ stats_->recordSSLSessionRemove();
+ }
+}
+
+SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl,
+ unsigned char* sess_id,
+ int id_len,
+ int* copyflag) {
+ SSLSessionCacheManager* manager = nullptr;
+ SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
+ manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+ if (manager == nullptr) {
+ LOG(FATAL) << "Null SSLSessionCacheManager in callback";
+ return nullptr;
+ }
+ return manager->getSession(ssl, sess_id, id_len, copyflag);
+}
+
+SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl,
+ unsigned char* session_id,
+ int id_len,
+ int* copyflag) {
+ VLOG(7) << "SSL get session callback";
+ SSL_SESSION* session = nullptr;
+ bool foreign = false;
+ char const* missReason = nullptr;
+
+ if (id_len < MIN_SESSION_ID_LENGTH) {
+ // We didn't generate this session so it's going to be a miss.
+ // This doesn't get logged or counted in the stats.
+ return nullptr;
+ }
+ string sessionId((char*)session_id, id_len);
+
+ AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
+
+ assert(sslSocket != nullptr);
+
+ // look it up in the local cache first
+ session = localCache_->lookupSession(sessionId);
+#ifdef SSL_SESSION_CB_WOULD_BLOCK
+ if (session == nullptr && externalCache_) {
+ // external cache might have the session
+ foreign = true;
+ if (!SSL_want_sess_cache_lookup(ssl)) {
+ missReason = "reason: No async cache support;";
+ } else {
+ PendingLookupMap::iterator pit = pendingLookups_.find(sessionId);
+ if (pit == pendingLookups_.end()) {
+ auto result = pendingLookups_.emplace(sessionId, PendingLookup());
+ // initiate fetch
+ VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" <<
+ sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
+ if (lookupCacheRecord(sessionId, sslSocket)) {
+ // response is pending
+ *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
+ return nullptr;
+ } else {
+ missReason = "reason: failed to send lookup request;";
+ pendingLookups_.erase(result.first);
+ }
+ } else {
+ // A lookup was already initiated from this thread
+ if (pit->second.request_in_progress) {
+ // Someone else initiated the request, attach
+ VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; "
+ "fd=" << sslSocket->getFd() << " id=" <<
+ SSLUtil::hexlify(sessionId);
+ std::unique_ptr<DelayedDestruction::DestructorGuard> dg(
+ new DelayedDestruction::DestructorGuard(sslSocket));
+ pit->second.waiters.push_back(
+ std::make_pair(sslSocket, std::move(dg)));
+ *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
+ return nullptr;
+ }
+ // request is complete
+ session = pit->second.session; // nullptr if our friend didn't have it
+ if (session != nullptr) {
+ CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+ }
+ }
+ }
+ }
+#endif
+
+ bool hit = (session != nullptr);
+ if (stats_) {
+ stats_->recordSSLSession(false, hit, foreign);
+ }
+ if (hit) {
+ sslSocket->setSessionIDResumed(true);
+ }
+
+ VLOG(4) << "Get SSL session [" <<
+ ((hit) ? "Hit" : "Miss") << "]: " <<
+ ((foreign) ? "external" : "local") << " cache; " <<
+ ((missReason != nullptr) ? missReason : "") << "fd=" <<
+ sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
+
+ // We already bumped the refcount
+ *copyflag = 0;
+
+ return session;
+}
+
+bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId,
+ SSL_SESSION* session) {
+ std::string sessionString;
+ uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr);
+ sessionString.resize(sessionLen);
+ uint8_t* cp = (uint8_t *)sessionString.data();
+ i2d_SSL_SESSION(session, &cp);
+ size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx());
+ return externalCache_->setAsync(sessionId, sessionString,
+ std::chrono::seconds(expiration));
+}
+
+bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId,
+ AsyncSSLSocket* sslSocket) {
+ auto cacheCtx = new SSLCacheProvider::CacheContext();
+ cacheCtx->sessionId = sessionId;
+ cacheCtx->session = nullptr;
+ cacheCtx->sslSocket = sslSocket;
+ cacheCtx->guard.reset(
+ new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket));
+ cacheCtx->manager = this;
+ bool res = externalCache_->getAsync(sessionId, cacheCtx);
+ if (!res) {
+ delete cacheCtx;
+ }
+ return res;
+}
+
+void SSLSessionCacheManager::restartSSLAccept(
+ const SSLCacheProvider::CacheContext* cacheCtx) {
+ PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId);
+ CHECK(pit != pendingLookups_.end());
+ pit->second.request_in_progress = false;
+ pit->second.session = cacheCtx->session;
+ VLOG(7) << "Restart SSL accept";
+ cacheCtx->sslSocket->restartSSLAccept();
+ for (const auto& attachedLookup: pit->second.waiters) {
+ // Wake up anyone else who was waiting for this session
+ VLOG(4) << "Restart SSL accept (waiters) for fd=" <<
+ attachedLookup.first->getFd();
+ attachedLookup.first->restartSSLAccept();
+ }
+ pendingLookups_.erase(pit);
+}
+
+void SSLSessionCacheManager::onGetSuccess(
+ SSLCacheProvider::CacheContext* cacheCtx,
+ const std::string& value) {
+ const uint8_t* cp = (uint8_t*)value.data();
+ cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length());
+ restartSSLAccept(cacheCtx);
+
+ /* Insert in the LRU after restarting all clients. The stats logic
+ * in getSession would treat this as a local hit otherwise.
+ */
+ localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_);
+ delete cacheCtx;
+}
+
+void SSLSessionCacheManager::onGetFailure(
+ SSLCacheProvider::CacheContext* cacheCtx) {
+ restartSSLAccept(cacheCtx);
+ delete cacheCtx;
+}
+
+} // namespace
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/wangle/ssl/SSLCacheProvider.h>
+#include <folly/wangle/ssl/SSLStats.h>
+
+#include <folly/EvictingCacheMap.h>
+#include <mutex>
+#include <folly/io/async/AsyncSSLSocket.h>
+
+namespace folly {
+
+class SSLStats;
+
+/**
+ * Basic SSL session cache map: Maps session id -> session
+ */
+typedef folly::EvictingCacheMap<std::string, SSL_SESSION*> SSLSessionCacheMap;
+
+/**
+ * Holds an SSLSessionCacheMap and associated lock
+ */
+class LocalSSLSessionCache: private boost::noncopyable {
+ public:
+ LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize);
+
+ ~LocalSSLSessionCache() {
+ std::lock_guard<std::mutex> g(lock);
+ // EvictingCacheMap dtor doesn't free values
+ sessionCache.clear();
+ }
+
+ SSLSessionCacheMap sessionCache;
+ std::mutex lock;
+ uint32_t removedSessions_{0};
+
+ private:
+
+ void pruneSessionCallback(const std::string& sessionId,
+ SSL_SESSION* session);
+};
+
+/**
+ * A sharded LRU for SSL sessions. The sharding is inteneded to reduce
+ * contention for the LRU locks. Assuming uniform distribution, two workers
+ * will contend for the same lock with probability 1 / n_buckets^2.
+ */
+class ShardedLocalSSLSessionCache : private boost::noncopyable {
+ public:
+ ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize,
+ uint32_t cacheCullSize) {
+ CHECK(n_buckets > 0);
+ maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets);
+ cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets);
+ if (maxCacheSize == 0) {
+ maxCacheSize = 1;
+ }
+ if (cacheCullSize == 0) {
+ cacheCullSize = 1;
+ }
+ for (uint32_t i = 0; i < n_buckets; i++) {
+ caches_.push_back(
+ std::unique_ptr<LocalSSLSessionCache>(
+ new LocalSSLSessionCache(maxCacheSize, cacheCullSize)));
+ }
+ }
+
+ SSL_SESSION* lookupSession(const std::string& sessionId) {
+ size_t bucket = hash(sessionId);
+ SSL_SESSION* session = nullptr;
+ std::lock_guard<std::mutex> g(caches_[bucket]->lock);
+
+ auto itr = caches_[bucket]->sessionCache.find(sessionId);
+ if (itr != caches_[bucket]->sessionCache.end()) {
+ session = itr->second;
+ }
+
+ if (session) {
+ CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+ }
+ return session;
+ }
+
+ void storeSession(const std::string& sessionId, SSL_SESSION* session,
+ SSLStats* stats) {
+ size_t bucket = hash(sessionId);
+ SSL_SESSION* oldSession = nullptr;
+ std::lock_guard<std::mutex> g(caches_[bucket]->lock);
+
+ auto itr = caches_[bucket]->sessionCache.find(sessionId);
+ if (itr != caches_[bucket]->sessionCache.end()) {
+ oldSession = itr->second;
+ }
+
+ if (oldSession) {
+ // LRUCacheMap doesn't free on overwrite, so 2x the work for us
+ // This can happen in race conditions
+ SSL_SESSION_free(oldSession);
+ }
+ caches_[bucket]->removedSessions_ = 0;
+ caches_[bucket]->sessionCache.set(sessionId, session, true);
+ if (stats) {
+ stats->recordSSLSessionFree(caches_[bucket]->removedSessions_);
+ }
+ }
+
+ void removeSession(const std::string& sessionId) {
+ size_t bucket = hash(sessionId);
+ std::lock_guard<std::mutex> g(caches_[bucket]->lock);
+ caches_[bucket]->sessionCache.erase(sessionId);
+ }
+
+ private:
+
+ /* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */
+ size_t hash(const std::string& key) {
+ CHECK(key.length() >= 2);
+ return (key[0] << 8 | key[1]) % caches_.size();
+ }
+
+ std::vector< std::unique_ptr<LocalSSLSessionCache> > caches_;
+};
+
+/* A socket/DestructorGuard pair */
+typedef std::pair<AsyncSSLSocket *,
+ std::unique_ptr<DelayedDestruction::DestructorGuard>>
+ AttachedLookup;
+
+/**
+ * PendingLookup structure
+ *
+ * Keeps track of clients waiting for an SSL session to be retrieved from
+ * the external cache provider.
+ */
+struct PendingLookup {
+ bool request_in_progress;
+ SSL_SESSION* session;
+ std::list<AttachedLookup> waiters;
+
+ PendingLookup() {
+ request_in_progress = true;
+ session = nullptr;
+ }
+};
+
+/* Maps SSL session id to a PendingLookup structure */
+typedef std::map<std::string, PendingLookup> PendingLookupMap;
+
+/**
+ * SSLSessionCacheManager handles all stateful session caching. There is an
+ * instance of this object per SSL VIP per thread, with a 1:1 correlation with
+ * SSL_CTX. The cache can work locally or in concert with an external cache
+ * to share sessions across instances.
+ *
+ * There is a single in memory session cache shared by all VIPs. The cache is
+ * split into N buckets (currently 16) with a separate lock per bucket. The
+ * VIP ID is hashed and stored as part of the session to handle the
+ * (very unlikely) case of session ID collision.
+ *
+ * When a new SSL session is created, it is added to the LRU cache and
+ * sent to the external cache to be stored. The external cache
+ * expiration is equal to the SSL session's expiration.
+ *
+ * When a resume request is received, SSLSessionCacheManager first looks in the
+ * local LRU cache for the VIP. If there is a miss there, an asynchronous
+ * request for this session is dispatched to the external cache. When the
+ * external cache query returns, the LRU cache is updated if the session was
+ * found, and the SSL_accept call is resumed.
+ *
+ * If additional resume requests for the same session ID arrive in the same
+ * thread while the request is pending, the 2nd - Nth callers attach to the
+ * original external cache requests and are resumed when it comes back. No
+ * attempt is made to coalesce external cache requests for the same session
+ * ID in different worker threads. Previous work did this, but the
+ * complexity was deemed to outweigh the potential savings.
+ *
+ */
+class SSLSessionCacheManager : private boost::noncopyable {
+ public:
+ /**
+ * Constructor. SSL session related callbacks will be set on the underlying
+ * SSL_CTX. vipId is assumed to a unique string identifying the VIP and must
+ * be the same on all servers that wish to share sessions via the same
+ * external cache.
+ */
+ SSLSessionCacheManager(
+ uint32_t maxCacheSize,
+ uint32_t cacheCullSize,
+ SSLContext* ctx,
+ const folly::SocketAddress& sockaddr,
+ const std::string& context,
+ EventBase* eventBase,
+ SSLStats* stats,
+ const std::shared_ptr<SSLCacheProvider>& externalCache);
+
+ virtual ~SSLSessionCacheManager();
+
+ /**
+ * Call this on shutdown to release the global instance of the
+ * ShardedLocalSSLSessionCache.
+ */
+ static void shutdown();
+
+ /**
+ * Callback for ExternalCache to call when an async get succeeds
+ * @param context The context that was passed to the async get request
+ * @param value Serialized session
+ */
+ void onGetSuccess(SSLCacheProvider::CacheContext* context,
+ const std::string& value);
+
+ /**
+ * Callback for ExternalCache to call when an async get fails, either
+ * because the requested session is not in the external cache or because
+ * of an error.
+ * @param context The context that was passed to the async get request
+ */
+ void onGetFailure(SSLCacheProvider::CacheContext* context);
+
+ private:
+
+ SSLContext* ctx_;
+ std::shared_ptr<ShardedLocalSSLSessionCache> localCache_;
+ PendingLookupMap pendingLookups_;
+ SSLStats* stats_{nullptr};
+ std::shared_ptr<SSLCacheProvider> externalCache_;
+
+ /**
+ * Invoked by openssl when a new SSL session is created
+ */
+ int newSession(SSL* ssl, SSL_SESSION* session);
+
+ /**
+ * Invoked by openssl when an SSL session is ejected from its internal cache.
+ * This can't be invoked in the current implementation because SSL's internal
+ * caching is disabled.
+ */
+ void removeSession(SSL_CTX* ctx, SSL_SESSION* session);
+
+ /**
+ * Invoked by openssl when a client requests a stateful session resumption.
+ * Triggers a lookup in our local cache and potentially an asynchronous
+ * request to an external cache.
+ */
+ SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id,
+ int id_len, int* copyflag);
+
+ /**
+ * Store a new session record in the external cache
+ */
+ bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session);
+
+ /**
+ * Lookup a session in the external cache for the specified SSL socket.
+ */
+ bool lookupCacheRecord(const std::string& sessionId,
+ AsyncSSLSocket* sslSock);
+
+ /**
+ * Restart all clients waiting for the answer to an external cache query
+ */
+ void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx);
+
+ /**
+ * Get or create the LRU cache for the given VIP ID
+ */
+ static std::shared_ptr<ShardedLocalSSLSessionCache> getLocalCache(
+ uint32_t maxCacheSize, uint32_t cacheCullSize);
+
+ /**
+ * static functions registered as callbacks to openssl via
+ * SSL_CTX_sess_set_new/get/remove_cb
+ */
+ static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
+ static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session);
+ static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id,
+ int id_len, int* copyflag);
+
+ static int32_t sExDataIndex_;
+ static std::shared_ptr<ShardedLocalSSLSessionCache> sCache_;
+ static std::mutex sCacheLock_;
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+class SSLStats {
+ public:
+ virtual ~SSLStats() noexcept {}
+
+ // downstream
+ virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0;
+ virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0;
+ virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign)
+ noexcept = 0;
+ virtual void recordSSLSessionRemove() noexcept = 0;
+ virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0;
+ virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0;
+ virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0;
+ virtual void recordClientRenegotiation() noexcept = 0;
+
+ // upstream
+ virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0;
+ virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0;
+ virtual void recordCryptoSSLExternalAttempt() noexcept = 0;
+ virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0;
+ virtual void recordCryptoSSLExternalApplicationException() noexcept = 0;
+ virtual void recordCryptoSSLExternalSuccess() noexcept = 0;
+ virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0;
+ virtual void recordCryptoSSLLocalAttempt() noexcept = 0;
+ virtual void recordCryptoSSLLocalSuccess() noexcept = 0;
+
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/ssl/SSLUtil.h>
+
+#include <folly/Memory.h>
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000105fL
+#define OPENSSL_GE_101 1
+#include <openssl/asn1.h>
+#include <openssl/x509v3.h>
+#else
+#undef OPENSSL_GE_101
+#endif
+
+namespace folly {
+
+std::mutex SSLUtil::sIndexLock_;
+
+std::unique_ptr<std::string> SSLUtil::getCommonName(const X509* cert) {
+ X509_NAME* subject = X509_get_subject_name((X509*)cert);
+ if (!subject) {
+ return nullptr;
+ }
+ char cn[ub_common_name + 1];
+ int res = X509_NAME_get_text_by_NID(subject, NID_commonName,
+ cn, ub_common_name);
+ if (res <= 0) {
+ return nullptr;
+ } else {
+ cn[ub_common_name] = '\0';
+ return folly::make_unique<std::string>(cn);
+ }
+}
+
+std::unique_ptr<std::list<std::string>> SSLUtil::getSubjectAltName(
+ const X509* cert) {
+#ifdef OPENSSL_GE_101
+ auto nameList = folly::make_unique<std::list<std::string>>();
+ GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i(
+ (X509*)cert, NID_subject_alt_name, nullptr, nullptr);
+ if (names) {
+ auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); });
+ size_t count = sk_GENERAL_NAME_num(names);
+ CHECK(count < std::numeric_limits<int>::max());
+ for (int i = 0; i < (int)count; ++i) {
+ GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i);
+ if (generalName->type == GEN_DNS) {
+ ASN1_STRING* s = generalName->d.dNSName;
+ const char* name = (const char*)ASN1_STRING_data(s);
+ // I can't find any docs on what a negative return value here
+ // would mean, so I'm going to ignore it.
+ auto len = ASN1_STRING_length(s);
+ DCHECK(len >= 0);
+ if (size_t(len) != strlen(name)) {
+ // Null byte(s) in the name; return an error rather than depending on
+ // the caller to safely handle this case.
+ return nullptr;
+ }
+ nameList->emplace_back(name);
+ }
+ }
+ }
+ return nameList;
+#else
+ return nullptr;
+#endif
+}
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/String.h>
+#include <mutex>
+#include <folly/io/async/AsyncSSLSocket.h>
+
+namespace folly {
+
+/**
+ * SSL session establish/resume status
+ *
+ * changing these values will break logging pipelines
+ */
+enum class SSLResumeEnum : uint8_t {
+ HANDSHAKE = 0,
+ RESUME_SESSION_ID = 1,
+ RESUME_TICKET = 3,
+ NA = 2
+};
+
+enum class SSLErrorEnum {
+ NO_ERROR,
+ TIMEOUT,
+ DROPPED
+};
+
+class SSLUtil {
+ private:
+ static std::mutex sIndexLock_;
+
+ public:
+ /**
+ * Ensures only one caller will allocate an ex_data index for a given static
+ * or global.
+ */
+ static void getSSLCtxExIndex(int* pindex) {
+ std::lock_guard<std::mutex> g(sIndexLock_);
+ if (*pindex < 0) {
+ *pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+ }
+ }
+
+ static void getRSAExIndex(int* pindex) {
+ std::lock_guard<std::mutex> g(sIndexLock_);
+ if (*pindex < 0) {
+ *pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+ }
+ }
+
+ static inline std::string hexlify(const std::string& binary) {
+ std::string hex;
+ folly::hexlify<std::string, std::string>(binary, hex);
+
+ return hex;
+ }
+
+ static inline const std::string& hexlify(const std::string& binary,
+ std::string& hex) {
+ folly::hexlify<std::string, std::string>(binary, hex);
+
+ return hex;
+ }
+
+ /**
+ * Return the SSL resume type for the given socket.
+ */
+ static inline SSLResumeEnum getResumeState(
+ AsyncSSLSocket* sslSocket) {
+ return sslSocket->getSSLSessionReused() ?
+ (sslSocket->sessionIDResumed() ?
+ SSLResumeEnum::RESUME_SESSION_ID :
+ SSLResumeEnum::RESUME_TICKET) :
+ SSLResumeEnum::HANDSHAKE;
+ }
+
+ /**
+ * Get the Common Name from an X.509 certificate
+ * @param cert certificate to inspect
+ * @return common name, or null if an error occurs
+ */
+ static std::unique_ptr<std::string> getCommonName(const X509* cert);
+
+ /**
+ * Get the Subject Alternative Name value(s) from an X.509 certificate
+ * @param cert certificate to inspect
+ * @return set of zero or more alternative names, or null if
+ * an error occurs
+ */
+ static std::unique_ptr<std::list<std::string>> getSubjectAltName(
+ const X509* cert);
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/wangle/ssl/TLSTicketKeyManager.h>
+
+#include <folly/wangle/ssl/SSLStats.h>
+#include <folly/wangle/ssl/SSLUtil.h>
+
+#include <folly/String.h>
+#include <openssl/aes.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+#include <folly/io/async/AsyncTimeout.h>
+
+#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+using std::string;
+
+namespace {
+
+const int kTLSTicketKeyNameLen = 4;
+const int kTLSTicketKeySaltLen = 12;
+
+}
+
+namespace folly {
+
+
+// TLSTicketKeyManager Implementation
+int32_t TLSTicketKeyManager::sExDataIndex_ = -1;
+
+TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats)
+ : ctx_(ctx),
+ randState_(0),
+ stats_(stats) {
+ SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
+ SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this);
+}
+
+TLSTicketKeyManager::~TLSTicketKeyManager() {
+}
+
+int
+TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName,
+ unsigned char* iv,
+ EVP_CIPHER_CTX* cipherCtx,
+ HMAC_CTX* hmacCtx, int encrypt) {
+ TLSTicketKeyManager* manager = nullptr;
+ SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
+ manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
+
+ if (manager == nullptr) {
+ LOG(FATAL) << "Null TLSTicketKeyManager in callback" ;
+ return -1;
+ }
+ return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt);
+}
+
+int
+TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName,
+ unsigned char* iv,
+ EVP_CIPHER_CTX* cipherCtx,
+ HMAC_CTX* hmacCtx, int encrypt) {
+ uint8_t salt[kTLSTicketKeySaltLen];
+ uint8_t* saltptr = nullptr;
+ uint8_t output[SHA256_DIGEST_LENGTH];
+ uint8_t* hmacKey = nullptr;
+ uint8_t* aesKey = nullptr;
+ TLSTicketKeySource* key = nullptr;
+ int result = 0;
+
+ if (encrypt) {
+ key = findEncryptionKey();
+ if (key == nullptr) {
+ // no keys available to encrypt
+ VLOG(2) << "No TLS ticket key found";
+ return -1;
+ }
+ VLOG(4) << "Encrypting new ticket with key name=" <<
+ SSLUtil::hexlify(key->keyName_);
+
+ // Get a random salt and write out key name
+ RAND_pseudo_bytes(salt, (int)sizeof(salt));
+ memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen);
+ memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen);
+
+ // Create the unique keys by hashing with the salt
+ makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output);
+ // This relies on the fact that SHA256 has 32 bytes of output
+ // and that AES-128 keys are 16 bytes
+ hmacKey = output;
+ aesKey = output + SHA256_DIGEST_LENGTH / 2;
+
+ // Initialize iv and cipher/mac CTX
+ RAND_pseudo_bytes(iv, AES_BLOCK_SIZE);
+ HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
+ EVP_sha256(), nullptr);
+ EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
+
+ result = 1;
+ } else {
+ key = findDecryptionKey(keyName);
+ if (key == nullptr) {
+ // no ticket found for decryption - will issue a new ticket
+ if (VLOG_IS_ON(4)) {
+ string skeyName((char *)keyName, kTLSTicketKeyNameLen);
+ VLOG(4) << "Can't find ticket key with name=" <<
+ SSLUtil::hexlify(skeyName)<< ", will generate new ticket";
+ }
+
+ result = 0;
+ } else {
+ VLOG(4) << "Decrypting ticket with key name=" <<
+ SSLUtil::hexlify(key->keyName_);
+
+ // Reconstruct the unique key via the salt
+ saltptr = keyName + kTLSTicketKeyNameLen;
+ makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output);
+ hmacKey = output;
+ aesKey = output + SHA256_DIGEST_LENGTH / 2;
+
+ // Initialize cipher/mac CTX
+ HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
+ EVP_sha256(), nullptr);
+ EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
+
+ result = 1;
+ }
+ }
+ // result records whether a ticket key was found to decrypt this ticket,
+ // not wether the session was re-used.
+ if (stats_) {
+ stats_->recordTLSTicket(encrypt, result);
+ }
+
+ return result;
+}
+
+bool
+TLSTicketKeyManager::setTLSTicketKeySeeds(
+ const std::vector<std::string>& oldSeeds,
+ const std::vector<std::string>& currentSeeds,
+ const std::vector<std::string>& newSeeds) {
+
+ bool result = true;
+
+ activeKeys_.clear();
+ ticketKeys_.clear();
+ ticketSeeds_.clear();
+ const std::vector<string> *seedList = &oldSeeds;
+ for (uint32_t i = 0; i < 3; i++) {
+ TLSTicketSeedType type = (TLSTicketSeedType)i;
+ if (type == SEED_CURRENT) {
+ seedList = ¤tSeeds;
+ } else if (type == SEED_NEW) {
+ seedList = &newSeeds;
+ }
+
+ for (const auto& seedInput: *seedList) {
+ TLSTicketSeed* seed = insertSeed(seedInput, type);
+ if (seed == nullptr) {
+ result = false;
+ continue;
+ }
+ insertNewKey(seed, 1, nullptr);
+ }
+ }
+ if (!result) {
+ VLOG(2) << "One or more seeds failed to decode";
+ }
+
+ if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) {
+ LOG(WARNING) << "No keys configured, falling back to default";
+ SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr);
+ return false;
+ }
+ SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(),
+ TLSTicketKeyManager::callback);
+
+ return true;
+}
+
+string
+TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n,
+ unsigned char* nameBuf) {
+ SHA256_CTX ctx;
+
+ SHA256_Init(&ctx);
+ SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_));
+ SHA256_Update(&ctx, &n, sizeof(n));
+ SHA256_Final(nameBuf, &ctx);
+ return string((char *)nameBuf, kTLSTicketKeyNameLen);
+}
+
+TLSTicketKeyManager::TLSTicketKeySource*
+TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
+ TLSTicketKeySource* prevKey) {
+ unsigned char nameBuf[SHA256_DIGEST_LENGTH];
+ std::unique_ptr<TLSTicketKeySource> newKey(new TLSTicketKeySource);
+
+ // This function supports hash chaining but it is not currently used.
+
+ if (prevKey != nullptr) {
+ hashNth(prevKey->keySource_, sizeof(prevKey->keySource_),
+ newKey->keySource_, 1);
+ } else {
+ // can't go backwards or the current is missing, start from the beginning
+ hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(),
+ newKey->keySource_, hashCount);
+ }
+
+ newKey->hashCount_ = hashCount;
+ newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf);
+ newKey->type_ = seed->type_;
+ auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_,
+ std::move(newKey)));
+
+ auto key = it.first->second.get();
+ if (key->type_ == SEED_CURRENT) {
+ activeKeys_.push_back(key);
+ }
+ VLOG(4) << "Adding key for " << hashCount << " type=" <<
+ (uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_);
+
+ return key;
+}
+
+void
+TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len,
+ unsigned char* output, uint32_t n) {
+ assert(n > 0);
+ for (uint32_t i = 0; i < n; i++) {
+ SHA256(input, input_len, output);
+ input = output;
+ input_len = SHA256_DIGEST_LENGTH;
+ }
+}
+
+TLSTicketKeyManager::TLSTicketSeed *
+TLSTicketKeyManager::insertSeed(const string& seedInput,
+ TLSTicketSeedType type) {
+ TLSTicketSeed* seed = nullptr;
+ string seedOutput;
+
+ if (!folly::unhexlify<string, string>(seedInput, seedOutput)) {
+ LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type <<
+ " seed=" << seedInput;
+ return seed;
+ }
+
+ seed = new TLSTicketSeed();
+ seed->seed_ = seedOutput;
+ seed->type_ = type;
+ SHA256((unsigned char *)seedOutput.data(), seedOutput.length(),
+ seed->seedName_);
+ ticketSeeds_.push_back(std::unique_ptr<TLSTicketSeed>(seed));
+
+ return seed;
+}
+
+TLSTicketKeyManager::TLSTicketKeySource *
+TLSTicketKeyManager::findEncryptionKey() {
+ TLSTicketKeySource* result = nullptr;
+ // call to rand here is a bit hokey since it's not cryptographically
+ // random, and is predictably seeded with 0. However, activeKeys_
+ // is probably not going to have very many keys in it, and most
+ // likely only 1.
+ size_t numKeys = activeKeys_.size();
+ if (numKeys > 0) {
+ result = activeKeys_[rand_r(&randState_) % numKeys];
+ }
+ return result;
+}
+
+TLSTicketKeyManager::TLSTicketKeySource *
+TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) {
+ string name((char *)keyName, kTLSTicketKeyNameLen);
+ TLSTicketKeySource* key = nullptr;
+ TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name);
+ if (mapit != ticketKeys_.end()) {
+ key = mapit->second.get();
+ }
+ return key;
+}
+
+void
+TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey,
+ size_t keyLen,
+ unsigned char* salt,
+ unsigned char* output) {
+ SHA256_CTX hash_ctx;
+
+ SHA256_Init(&hash_ctx);
+ SHA256_Update(&hash_ctx, parentKey, keyLen);
+ SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen);
+ SHA256_Final(output, &hash_ctx);
+}
+
+} // namespace
+#endif
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+#include <folly/io/async/SSLContext.h>
+#include <folly/io/async/EventBase.h>
+
+namespace folly {
+
+#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
+class TLSTicketKeyManager {};
+#else
+class SSLStats;
+/**
+ * The TLSTicketKeyManager handles TLS ticket key encryption and decryption in
+ * a way that facilitates sharing the ticket keys across a range of servers.
+ * Hash chaining is employed to achieve frequent key rotation with minimal
+ * configuration change. The scheme is as follows:
+ *
+ * The manager is supplied with three lists of seeds (old, current and new).
+ * The config should be updated with new seeds periodically (e.g., daily).
+ * 3 config changes are recommended to achieve the smoothest seed rotation
+ * eg:
+ * 1. Introduce new seed in the push prior to rotation
+ * 2. Rotation push
+ * 3. Remove old seeds in the push following rotation
+ *
+ * Multiple seeds are supported but only a single seed is required.
+ *
+ * Generating encryption keys from the seed works as follows. For a given
+ * seed, hash forward N times where N is currently the constant 1.
+ * This is the base key. The name of the base key is the first 4
+ * bytes of hash(hash(seed), N). This is copied into the first 4 bytes of the
+ * TLS ticket key name field.
+ *
+ * For each new ticket encryption, the manager generates a random 12 byte salt.
+ * Hash the salt and the base key together to form the encryption key for
+ * that ticket. The salt is included in the ticket's 'key name' field so it
+ * can be used to derive the decryption key. The salt is copied into the second
+ * 8 bytes of the TLS ticket key name field.
+ *
+ * A key is valid for decryption for the lifetime of the instance.
+ * Sessions will be valid for less time than that, which results in an extra
+ * symmetric decryption to discover the session is expired.
+ *
+ * A TLSTicketKeyManager should be used in only one thread, and should have
+ * a 1:1 relationship with the SSLContext provided.
+ *
+ */
+class TLSTicketKeyManager : private boost::noncopyable {
+ public:
+
+ explicit TLSTicketKeyManager(folly::SSLContext* ctx,
+ SSLStats* stats);
+
+ virtual ~TLSTicketKeyManager();
+
+ /**
+ * SSL callback to set up encryption/decryption context for a TLS Ticket Key.
+ *
+ * This will be supplied to the SSL library via
+ * SSL_CTX_set_tlsext_ticket_key_cb.
+ */
+ static int callback(SSL* ssl, unsigned char* keyName,
+ unsigned char* iv,
+ EVP_CIPHER_CTX* cipherCtx,
+ HMAC_CTX* hmacCtx, int encrypt);
+
+ /**
+ * Initialize the manager with three sets of seeds. There must be at least
+ * one current seed, or the manager will revert to the default SSL behavior.
+ *
+ * @param oldSeeds Seeds previously used which can still decrypt.
+ * @param currentSeeds Seeds to use for new ticket encryptions.
+ * @param newSeeds Seeds which will be used soon, can be used to decrypt
+ * in case some servers in the cluster have already rotated.
+ */
+ bool setTLSTicketKeySeeds(const std::vector<std::string>& oldSeeds,
+ const std::vector<std::string>& currentSeeds,
+ const std::vector<std::string>& newSeeds);
+
+ private:
+ enum TLSTicketSeedType {
+ SEED_OLD = 0,
+ SEED_CURRENT,
+ SEED_NEW
+ };
+
+ /* The seeds supplied by the configuration */
+ struct TLSTicketSeed {
+ std::string seed_;
+ TLSTicketSeedType type_;
+ unsigned char seedName_[SHA256_DIGEST_LENGTH];
+ };
+
+ struct TLSTicketKeySource {
+ int32_t hashCount_;
+ std::string keyName_;
+ TLSTicketSeedType type_;
+ unsigned char keySource_[SHA256_DIGEST_LENGTH];
+ };
+
+ /**
+ * Method to setup encryption/decryption context for a TLS Ticket Key
+ *
+ * OpenSSL documentation is thin on the return value semantics.
+ *
+ * For encrypt=1, return < 0 on error, >= 0 for successfully initialized
+ * For encrypt=0, return < 0 on error, 0 on key not found
+ * 1 on key found, 2 renew_ticket
+ *
+ * renew_ticket means a new ticket will be issued. We could return this value
+ * when receiving a ticket encrypted with a key derived from an OLD seed.
+ * However, session_timeout seconds after deploying with a seed
+ * rotated from CURRENT -> OLD, there will be no valid tickets outstanding
+ * encrypted with the old key. This grace period means no unnecessary
+ * handshakes will be performed. If the seed is believed compromised, it
+ * should NOT be configured as an OLD seed.
+ */
+ int processTicket(SSL* ssl, unsigned char* keyName,
+ unsigned char* iv,
+ EVP_CIPHER_CTX* cipherCtx,
+ HMAC_CTX* hmacCtx, int encrypt);
+
+ // Creates the name for the nth key generated from seed
+ std::string makeKeyName(TLSTicketSeed* seed, uint32_t n,
+ unsigned char* nameBuf);
+
+ /**
+ * Creates the key hashCount hashes from the given seed and inserts it in
+ * ticketKeys. A naked pointer to the key is returned for additional
+ * processing if needed.
+ */
+ TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
+ TLSTicketKeySource* prevKeySource);
+
+ /**
+ * hashes input N times placing result in output, which must be at least
+ * SHA256_DIGEST_LENGTH long.
+ */
+ void hashNth(const unsigned char* input, size_t input_len,
+ unsigned char* output, uint32_t n);
+
+ /**
+ * Adds the given seed to the manager
+ */
+ TLSTicketSeed* insertSeed(const std::string& seedInput,
+ TLSTicketSeedType type);
+
+ /**
+ * Locate a key for encrypting a new ticket
+ */
+ TLSTicketKeySource* findEncryptionKey();
+
+ /**
+ * Locate a key for decrypting a ticket with the given keyName
+ */
+ TLSTicketKeySource* findDecryptionKey(unsigned char* keyName);
+
+ /**
+ * Derive a unique key from the parent key and the salt via hashing
+ */
+ void makeUniqueKeys(unsigned char* parentKey, size_t keyLen,
+ unsigned char* salt, unsigned char* output);
+
+ /**
+ * For standalone decryption utility
+ */
+ friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager,
+ const std::string& testTicket,
+ SSL_SESSION **psess);
+
+ typedef std::vector<std::unique_ptr<TLSTicketSeed>> TLSTicketSeedList;
+ typedef std::map<std::string, std::unique_ptr<TLSTicketKeySource> >
+ TLSTicketKeyMap;
+ typedef std::vector<TLSTicketKeySource *> TLSActiveKeyList;
+
+ TLSTicketSeedList ticketSeeds_;
+ // All key sources that can be used for decryption
+ TLSTicketKeyMap ticketKeys_;
+ // Key sources that can be used for encryption
+ TLSActiveKeyList activeKeys_;
+
+ folly::SSLContext* ctx_;
+ uint32_t randState_;
+ SSLStats* stats_{nullptr};
+
+ static int32_t sExDataIndex_;
+};
+#endif
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#pragma once
+
+namespace folly {
+
+struct TLSTicketKeySeeds {
+ std::vector<std::string> oldSeeds;
+ std::vector<std::string> currentSeeds;
+ std::vector<std::string> newSeeds;
+};
+
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/Portability.h>
+#include <folly/io/async/EventBase.h>
+#include <gflags/gflags.h>
+#include <iostream>
+#include <thread>
+#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <vector>
+
+using namespace std;
+using namespace folly;
+
+DEFINE_int32(clients, 1, "Number of simulated SSL clients");
+DEFINE_int32(threads, 1, "Number of threads to spread clients across");
+DEFINE_int32(requests, 2, "Total number of requests per client");
+DEFINE_int32(port, 9423, "Server port");
+DEFINE_bool(sticky, false, "A given client sends all reqs to one "
+ "(random) server");
+DEFINE_bool(global, false, "All clients in a thread use the same SSL session");
+DEFINE_bool(handshakes, false, "Force 100% handshakes");
+
+string f_servers[10];
+int f_num_servers = 0;
+int tnum = 0;
+
+class ClientRunner {
+ public:
+
+ ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {}
+ void run();
+
+ int reqs;
+ int hits;
+ int miss;
+ int num;
+};
+
+class SSLCacheClient : public AsyncSocket::ConnectCallback,
+ public AsyncSSLSocket::HandshakeCB
+{
+private:
+ EventBase* eventBase_;
+ int currReq_;
+ int serverIdx_;
+ AsyncSocket* socket_;
+ AsyncSSLSocket* sslSocket_;
+ SSL_SESSION* session_;
+ SSL_SESSION **pSess_;
+ std::shared_ptr<SSLContext> ctx_;
+ ClientRunner* cr_;
+
+public:
+ SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr);
+ ~SSLCacheClient() {
+ if (session_ && !FLAGS_global)
+ SSL_SESSION_free(session_);
+ if (socket_ != nullptr) {
+ if (sslSocket_ != nullptr) {
+ sslSocket_->destroy();
+ sslSocket_ = nullptr;
+ }
+ socket_->destroy();
+ socket_ = nullptr;
+ }
+ };
+
+ void start();
+
+ virtual void connectSuccess() noexcept;
+
+ virtual void connectErr(const AsyncSocketException& ex)
+ noexcept ;
+
+ virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept;
+
+ virtual void handshakeErr(
+ AsyncSSLSocket* sock,
+ const AsyncSocketException& ex) noexcept;
+
+};
+
+int
+main(int argc, char* argv[])
+{
+ gflags::SetUsageMessage(std::string("\n\n"
+"usage: sslcachetest [options] -c <clients> -t <threads> servers\n"
+));
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ int reqs = 0;
+ int hits = 0;
+ int miss = 0;
+ struct timeval start;
+ struct timeval end;
+ struct timeval result;
+
+ srand((unsigned int)time(nullptr));
+
+ for (int i = 1; i < argc; i++) {
+ f_servers[f_num_servers++] = argv[i];
+ }
+ if (f_num_servers == 0) {
+ cout << "require at least one server\n";
+ return 1;
+ }
+
+ gettimeofday(&start, nullptr);
+ if (FLAGS_threads == 1) {
+ ClientRunner r;
+ r.run();
+ gettimeofday(&end, nullptr);
+ reqs = r.reqs;
+ hits = r.hits;
+ miss = r.miss;
+ }
+ else {
+ std::vector<ClientRunner> clients;
+ std::vector<std::thread> threads;
+ for (int t = 0; t < FLAGS_threads; t++) {
+ threads.emplace_back([&] {
+ clients[t].run();
+ });
+ }
+ for (auto& thr: threads) {
+ thr.join();
+ }
+ gettimeofday(&end, nullptr);
+
+ for (const auto& client: clients) {
+ reqs += client.reqs;
+ hits += client.hits;
+ miss += client.miss;
+ }
+ }
+
+ timersub(&end, &start, &result);
+
+ cout << "Requests: " << reqs << endl;
+ cout << "Handshakes: " << miss << endl;
+ cout << "Resumes: " << hits << endl;
+ cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 <<
+ endl;
+
+ cout << "ops/sec: " << (reqs * 1.0) /
+ ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl;
+
+ return 0;
+}
+
+void
+ClientRunner::run()
+{
+ EventBase eb;
+ std::list<SSLCacheClient *> clients;
+ SSL_SESSION* session = nullptr;
+
+ for (int i = 0; i < FLAGS_clients; i++) {
+ SSLCacheClient* c = new SSLCacheClient(&eb, &session, this);
+ c->start();
+ clients.push_back(c);
+ }
+
+ eb.loop();
+
+ for (auto it = clients.begin(); it != clients.end(); it++) {
+ delete* it;
+ }
+
+ reqs += hits + miss;
+}
+
+SSLCacheClient::SSLCacheClient(EventBase* eb,
+ SSL_SESSION **pSess,
+ ClientRunner* cr)
+ : eventBase_(eb),
+ currReq_(0),
+ serverIdx_(0),
+ socket_(nullptr),
+ sslSocket_(nullptr),
+ session_(nullptr),
+ pSess_(pSess),
+ cr_(cr)
+{
+ ctx_.reset(new SSLContext());
+ ctx_->setOptions(SSL_OP_NO_TICKET);
+}
+
+void
+SSLCacheClient::start()
+{
+ if (currReq_ >= FLAGS_requests) {
+ cout << "+";
+ return;
+ }
+
+ if (currReq_ == 0 || !FLAGS_sticky) {
+ serverIdx_ = rand() % f_num_servers;
+ }
+ if (socket_ != nullptr) {
+ if (sslSocket_ != nullptr) {
+ sslSocket_->destroy();
+ sslSocket_ = nullptr;
+ }
+ socket_->destroy();
+ socket_ = nullptr;
+ }
+ socket_ = new AsyncSocket(eventBase_);
+ socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port);
+}
+
+void
+SSLCacheClient::connectSuccess() noexcept
+{
+ sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(),
+ false);
+
+ if (!FLAGS_handshakes) {
+ if (session_ != nullptr)
+ sslSocket_->setSSLSession(session_);
+ else if (FLAGS_global && pSess_ != nullptr)
+ sslSocket_->setSSLSession(*pSess_);
+ }
+ sslSocket_->sslConn(this);
+}
+
+void
+SSLCacheClient::connectErr(const AsyncSocketException& ex)
+ noexcept
+{
+ cout << "connectError: " << ex.what() << endl;
+}
+
+void
+SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept
+{
+ if (sslSocket_->getSSLSessionReused()) {
+ cr_->hits++;
+ } else {
+ cr_->miss++;
+ if (session_ != nullptr) {
+ SSL_SESSION_free(session_);
+ }
+ session_ = sslSocket_->getSSLSession();
+ if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) {
+ *pSess_ = session_;
+ }
+ }
+ if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) {
+ cout << ".";
+ cout.flush();
+ }
+ sslSocket_->closeNow();
+ currReq_++;
+ this->start();
+}
+
+void
+SSLCacheClient::handshakeErr(
+ AsyncSSLSocket* sock,
+ const AsyncSocketException& ex)
+ noexcept
+{
+ cout << "handshakeError: " << ex.what() << endl;
+}
--- /dev/null
+/*
+ * Copyright (c) 2014, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ *
+ */
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/SSLContext.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include <folly/wangle/ssl/SSLContextManager.h>
+#include <folly/wangle/acceptor/DomainNameMisc.h>
+
+using std::shared_ptr;
+
+namespace folly {
+
+TEST(SSLContextManagerTest, Test1)
+{
+ EventBase eventBase;
+ SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_",
+ true, nullptr);
+ auto www_facebook_com_ctx = std::make_shared<SSLContext>();
+ auto start_facebook_com_ctx = std::make_shared<SSLContext>();
+ auto start_abc_facebook_com_ctx = std::make_shared<SSLContext>();
+
+ sslCtxMgr.insertSSLCtxByDomainName(
+ "www.facebook.com",
+ strlen("www.facebook.com"),
+ www_facebook_com_ctx);
+ sslCtxMgr.insertSSLCtxByDomainName(
+ "www.facebook.com",
+ strlen("www.facebook.com"),
+ www_facebook_com_ctx);
+ try {
+ sslCtxMgr.insertSSLCtxByDomainName(
+ "www.facebook.com",
+ strlen("www.facebook.com"),
+ std::make_shared<SSLContext>());
+ } catch (const std::exception& ex) {
+ }
+ sslCtxMgr.insertSSLCtxByDomainName(
+ "*.facebook.com",
+ strlen("*.facebook.com"),
+ start_facebook_com_ctx);
+ sslCtxMgr.insertSSLCtxByDomainName(
+ "*.abc.facebook.com",
+ strlen("*.abc.facebook.com"),
+ start_abc_facebook_com_ctx);
+ try {
+ sslCtxMgr.insertSSLCtxByDomainName(
+ "*.abc.facebook.com",
+ strlen("*.abc.facebook.com"),
+ std::make_shared<SSLContext>());
+ FAIL();
+ } catch (const std::exception& ex) {
+ }
+
+ shared_ptr<SSLContext> retCtx;
+ retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com"));
+ EXPECT_EQ(retCtx, www_facebook_com_ctx);
+ retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com"));
+ EXPECT_EQ(retCtx, www_facebook_com_ctx);
+ EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com")));
+
+ retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com"));
+ EXPECT_EQ(retCtx, start_facebook_com_ctx);
+ retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com"));
+ EXPECT_EQ(retCtx, start_facebook_com_ctx);
+
+ retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com"));
+ EXPECT_EQ(retCtx, start_abc_facebook_com_ctx);
+
+ // ensure "facebook.com" does not match "*.facebook.com"
+ EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com")));
+ // ensure "Xfacebook.com" does not match "*.facebook.com"
+ EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com")));
+ // ensure wildcard name only matches one domain up
+ EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com")));
+
+ eventBase.loop(); // Clean up events before SSLContextManager is destructed
+}
+
+}