From: James Sedgwick Date: Thu, 18 Dec 2014 20:20:26 +0000 (-0800) Subject: codemod: merge folly/wangle and folly/experimental/wangle X-Git-Tag: v0.22.0~86 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=5e8d8951109ed9bcc29fed28dfe62d9d51b3b6d4;p=folly.git codemod: merge folly/wangle and folly/experimental/wangle Summary: Various TARGETS could definitely use some rearranging but I'd rather wait until we're cut over to the new repo Test Plan: wait for contbuild Reviewed By: davejwatson@fb.com Subscribers: ptarjan, joelm, trunkagent, hphp-diffs@, ps, fbcode-common-diffs@, fugalh, alandau, bmatheny, everstore-dev@, mwa, jgehring, fuegen, mshneer, folly-diffs@, hannesr FB internal diff: D1740858 Tasks: 5802833 Signature: t1:1740858:1418752569:4d7d9c5b955e4d9fab4b322cf08a3d285e3db7ce --- diff --git a/folly/Makefile.am b/folly/Makefile.am index 4a044257..95cc3143 100644 --- a/folly/Makefile.am +++ b/folly/Makefile.am @@ -74,49 +74,6 @@ nobase_follyinclude_HEADERS = \ experimental/io/FsUtil.h \ experimental/Singleton.h \ experimental/TestUtil.h \ - experimental/wangle/channel/AsyncSocketHandler.h \ - experimental/wangle/channel/ChannelHandler.h \ - experimental/wangle/channel/ChannelHandlerContext.h \ - experimental/wangle/channel/ChannelPipeline.h \ - experimental/wangle/channel/OutputBufferingHandler.h \ - experimental/wangle/concurrent/BlockingQueue.h \ - experimental/wangle/concurrent/Codel.h \ - experimental/wangle/concurrent/CPUThreadPoolExecutor.h \ - experimental/wangle/concurrent/FutureExecutor.h \ - experimental/wangle/concurrent/GlobalExecutor.h \ - experimental/wangle/concurrent/IOExecutor.h \ - experimental/wangle/concurrent/IOThreadPoolExecutor.h \ - experimental/wangle/concurrent/LifoSemMPMCQueue.h \ - experimental/wangle/concurrent/NamedThreadFactory.h \ - experimental/wangle/concurrent/ThreadFactory.h \ - experimental/wangle/concurrent/ThreadPoolExecutor.h \ - experimental/wangle/rx/Observable.h \ - experimental/wangle/rx/Observer.h \ - experimental/wangle/rx/Subject.h \ - experimental/wangle/rx/Subscription.h \ - experimental/wangle/rx/types.h \ - experimental/wangle/ConnectionManager.h \ - experimental/wangle/ManagedConnection.h \ - experimental/wangle/acceptor/Acceptor.h \ - experimental/wangle/acceptor/ConnectionCounter.h \ - experimental/wangle/acceptor/SocketOptions.h \ - experimental/wangle/acceptor/DomainNameMisc.h \ - experimental/wangle/acceptor/LoadShedConfiguration.h \ - experimental/wangle/acceptor/NetworkAddress.h \ - experimental/wangle/acceptor/ServerSocketConfig.h \ - experimental/wangle/acceptor/TransportInfo.h \ - experimental/wangle/ssl/ClientHelloExtStats.h \ - experimental/wangle/ssl/DHParam.h \ - experimental/wangle/ssl/PasswordInFile.h \ - experimental/wangle/ssl/SSLCacheOptions.h \ - experimental/wangle/ssl/SSLCacheProvider.h \ - experimental/wangle/ssl/SSLContextConfig.h \ - experimental/wangle/ssl/SSLContextManager.h \ - experimental/wangle/ssl/SSLSessionCacheManager.h \ - experimental/wangle/ssl/SSLStats.h \ - experimental/wangle/ssl/SSLUtil.h \ - experimental/wangle/ssl/TLSTicketKeyManager.h \ - experimental/wangle/ssl/TLSTicketKeySeeds.h \ FBString.h \ FBVector.h \ File.h \ @@ -238,6 +195,32 @@ nobase_follyinclude_HEADERS = \ Uri-inl.h \ Varint.h \ VersionCheck.h \ + wangle/acceptor/Acceptor.h \ + wangle/acceptor/ConnectionCounter.h \ + wangle/acceptor/ConnectionManager.h \ + wangle/acceptor/DomainNameMisc.h \ + wangle/acceptor/LoadShedConfiguration.h \ + wangle/acceptor/ManagedConnection.h \ + wangle/acceptor/NetworkAddress.h \ + wangle/acceptor/ServerSocketConfig.h \ + wangle/acceptor/SocketOptions.h \ + wangle/acceptor/TransportInfo.h \ + wangle/channel/AsyncSocketHandler.h \ + wangle/channel/ChannelHandler.h \ + wangle/channel/ChannelHandlerContext.h \ + wangle/channel/ChannelPipeline.h \ + wangle/channel/OutputBufferingHandler.h \ + wangle/concurrent/BlockingQueue.h \ + wangle/concurrent/Codel.h \ + wangle/concurrent/CPUThreadPoolExecutor.h \ + wangle/concurrent/FutureExecutor.h \ + wangle/concurrent/IOExecutor.h \ + wangle/concurrent/IOThreadPoolExecutor.h \ + wangle/concurrent/GlobalExecutor.h \ + wangle/concurrent/LifoSemMPMCQueue.h \ + wangle/concurrent/NamedThreadFactory.h \ + wangle/concurrent/ThreadFactory.h \ + wangle/concurrent/ThreadPoolExecutor.h \ wangle/futures/Deprecated.h \ wangle/futures/Future-inl.h \ wangle/futures/Future.h \ @@ -252,7 +235,24 @@ nobase_follyinclude_HEADERS = \ wangle/futures/Try.h \ wangle/futures/WangleException.h \ wangle/futures/detail/Core.h \ - wangle/futures/detail/FSM.h + wangle/futures/detail/FSM.h \ + wangle/rx/Observable.h \ + wangle/rx/Observer.h \ + wangle/rx/Subject.h \ + wangle/rx/Subscription.h \ + wangle/rx/types.h \ + wangle/ssl/ClientHelloExtStats.h \ + wangle/ssl/DHParam.h \ + wangle/ssl/PasswordInFile.h \ + wangle/ssl/SSLCacheOptions.h \ + wangle/ssl/SSLCacheProvider.h \ + wangle/ssl/SSLContextConfig.h \ + wangle/ssl/SSLContextManager.h \ + wangle/ssl/SSLSessionCacheManager.h \ + wangle/ssl/SSLStats.h \ + wangle/ssl/SSLUtil.h \ + wangle/ssl/TLSTicketKeyManager.h \ + wangle/ssl/TLSTicketKeySeeds.h FormatTables.cpp: build/generate_format_tables.py build/generate_format_tables.py @@ -324,28 +324,28 @@ libfolly_la_SOURCES = \ TimeoutQueue.cpp \ Uri.cpp \ Version.cpp \ - wangle/futures/InlineExecutor.cpp \ - wangle/futures/ManualExecutor.cpp \ experimental/io/FsUtil.cpp \ experimental/Singleton.cpp \ experimental/TestUtil.cpp \ - experimental/wangle/concurrent/CPUThreadPoolExecutor.cpp \ - experimental/wangle/concurrent/Codel.cpp \ - experimental/wangle/concurrent/GlobalExecutor.cpp \ - experimental/wangle/concurrent/IOExecutor.cpp \ - experimental/wangle/concurrent/IOThreadPoolExecutor.cpp \ - experimental/wangle/concurrent/ThreadPoolExecutor.cpp \ - experimental/wangle/ConnectionManager.cpp \ - experimental/wangle/ManagedConnection.cpp \ - experimental/wangle/acceptor/Acceptor.cpp \ - experimental/wangle/acceptor/SocketOptions.cpp \ - experimental/wangle/acceptor/LoadShedConfiguration.cpp \ - experimental/wangle/acceptor/TransportInfo.cpp \ - experimental/wangle/ssl/PasswordInFile.cpp \ - experimental/wangle/ssl/SSLContextManager.cpp \ - experimental/wangle/ssl/SSLSessionCacheManager.cpp \ - experimental/wangle/ssl/SSLUtil.cpp \ - experimental/wangle/ssl/TLSTicketKeyManager.cpp + wangle/acceptor/Acceptor.cpp \ + wangle/acceptor/ConnectionManager.cpp \ + wangle/acceptor/LoadShedConfiguration.cpp \ + wangle/acceptor/ManagedConnection.cpp \ + wangle/acceptor/SocketOptions.cpp \ + wangle/acceptor/TransportInfo.cpp \ + wangle/concurrent/CPUThreadPoolExecutor.cpp \ + wangle/concurrent/Codel.cpp \ + wangle/concurrent/IOExecutor.cpp \ + wangle/concurrent/IOThreadPoolExecutor.cpp \ + wangle/concurrent/GlobalExecutor.cpp \ + wangle/concurrent/ThreadPoolExecutor.cpp \ + wangle/futures/InlineExecutor.cpp \ + wangle/futures/ManualExecutor.cpp \ + wangle/ssl/PasswordInFile.cpp \ + wangle/ssl/SSLContextManager.cpp \ + wangle/ssl/SSLSessionCacheManager.cpp \ + wangle/ssl/SSLUtil.cpp \ + wangle/ssl/TLSTicketKeyManager.cpp if HAVE_LINUX nobase_follyinclude_HEADERS += \ diff --git a/folly/experimental/wangle/ConnectionManager.cpp b/folly/experimental/wangle/ConnectionManager.cpp deleted file mode 100644 index 8489fa08..00000000 --- a/folly/experimental/wangle/ConnectionManager.cpp +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -using folly::HHWheelTimer; -using std::chrono::milliseconds; - -namespace folly { namespace wangle { - -ConnectionManager::ConnectionManager(EventBase* eventBase, - milliseconds timeout, Callback* callback) - : connTimeouts_(new HHWheelTimer(eventBase)), - callback_(callback), - eventBase_(eventBase), - idleIterator_(conns_.end()), - idleLoopCallback_(this), - timeout_(timeout) { - -} - -void -ConnectionManager::addConnection(ManagedConnection* connection, - bool timeout) { - CHECK_NOTNULL(connection); - ConnectionManager* oldMgr = connection->getConnectionManager(); - if (oldMgr != this) { - if (oldMgr) { - // 'connection' was being previously managed in a different thread. - // We must remove it from that manager before adding it to this one. - oldMgr->removeConnection(connection); - } - conns_.push_back(*connection); - connection->setConnectionManager(this); - if (callback_) { - callback_->onConnectionAdded(*this); - } - } - if (timeout) { - scheduleTimeout(connection); - } -} - -void -ConnectionManager::scheduleTimeout(ManagedConnection* connection) { - if (timeout_ > std::chrono::milliseconds(0)) { - connTimeouts_->scheduleTimeout(connection, timeout_); - } -} - -void ConnectionManager::scheduleTimeout( - folly::HHWheelTimer::Callback* callback, - std::chrono::milliseconds timeout) { - connTimeouts_->scheduleTimeout(callback, timeout); -} - -void -ConnectionManager::removeConnection(ManagedConnection* connection) { - if (connection->getConnectionManager() == this) { - connection->cancelTimeout(); - connection->setConnectionManager(nullptr); - - // Un-link the connection from our list, being careful to keep the iterator - // that we're using for idle shedding valid - auto it = conns_.iterator_to(*connection); - if (it == idleIterator_) { - ++idleIterator_; - } - conns_.erase(it); - - if (callback_) { - callback_->onConnectionRemoved(*this); - if (getNumConnections() == 0) { - callback_->onEmpty(*this); - } - } - } -} - -void -ConnectionManager::initiateGracefulShutdown( - std::chrono::milliseconds idleGrace) { - if (idleGrace.count() > 0) { - idleLoopCallback_.scheduleTimeout(idleGrace); - VLOG(3) << "Scheduling idle grace period of " << idleGrace.count() << "ms"; - } else { - action_ = ShutdownAction::DRAIN2; - VLOG(3) << "proceeding directly to closing idle connections"; - } - drainAllConnections(); -} - -void -ConnectionManager::drainAllConnections() { - DestructorGuard g(this); - size_t numCleared = 0; - size_t numKept = 0; - - auto it = idleIterator_ == conns_.end() ? - conns_.begin() : idleIterator_; - - while (it != conns_.end() && (numKept + numCleared) < 64) { - ManagedConnection& conn = *it++; - if (action_ == ShutdownAction::DRAIN1) { - conn.notifyPendingShutdown(); - } else { - // Second time around: close idle sessions. If they aren't idle yet, - // have them close when they are idle - if (conn.isBusy()) { - numKept++; - } else { - numCleared++; - } - conn.closeWhenIdle(); - } - } - - if (action_ == ShutdownAction::DRAIN2) { - VLOG(2) << "Idle connections cleared: " << numCleared << - ", busy conns kept: " << numKept; - } - if (it != conns_.end()) { - idleIterator_ = it; - eventBase_->runInLoop(&idleLoopCallback_); - } else { - action_ = ShutdownAction::DRAIN2; - } -} - -void -ConnectionManager::dropAllConnections() { - DestructorGuard g(this); - - // Iterate through our connection list, and drop each connection. - VLOG(3) << "connections to drop: " << conns_.size(); - idleLoopCallback_.cancelTimeout(); - unsigned i = 0; - while (!conns_.empty()) { - ManagedConnection& conn = conns_.front(); - conns_.pop_front(); - conn.cancelTimeout(); - conn.setConnectionManager(nullptr); - // For debugging purposes, dump information about the first few - // connections. - static const unsigned MAX_CONNS_TO_DUMP = 2; - if (++i <= MAX_CONNS_TO_DUMP) { - conn.dumpConnectionState(3); - } - conn.dropConnection(); - } - idleIterator_ = conns_.end(); - idleLoopCallback_.cancelLoopCallback(); - - if (callback_) { - callback_->onEmpty(*this); - } -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/ConnectionManager.h b/folly/experimental/wangle/ConnectionManager.h deleted file mode 100644 index 9ac356d6..00000000 --- a/folly/experimental/wangle/ConnectionManager.h +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace folly { namespace wangle { - -/** - * A ConnectionManager keeps track of ManagedConnections. - */ -class ConnectionManager: public folly::DelayedDestruction { - public: - - /** - * Interface for an optional observer that's notified about - * various events in a ConnectionManager - */ - class Callback { - public: - virtual ~Callback() {} - - /** - * Invoked when the number of connections managed by the - * ConnectionManager changes from nonzero to zero. - */ - virtual void onEmpty(const ConnectionManager& cm) = 0; - - /** - * Invoked when a connection is added to the ConnectionManager. - */ - virtual void onConnectionAdded(const ConnectionManager& cm) = 0; - - /** - * Invoked when a connection is removed from the ConnectionManager. - */ - virtual void onConnectionRemoved(const ConnectionManager& cm) = 0; - }; - - typedef std::unique_ptr UniquePtr; - - /** - * Returns a new instance of ConnectionManager wrapped in a unique_ptr - */ - template - static UniquePtr makeUnique(Args&&... args) { - return folly::make_unique( - std::forward(args)...); - } - - /** - * Constructor not to be used by itself. - */ - ConnectionManager(folly::EventBase* eventBase, - std::chrono::milliseconds timeout, - Callback* callback = nullptr); - - /** - * Add a connection to the set of connections managed by this - * ConnectionManager. - * - * @param connection The connection to add. - * @param timeout Whether to immediately register this connection - * for an idle timeout callback. - */ - void addConnection(ManagedConnection* connection, - bool timeout = false); - - /** - * Schedule a timeout callback for a connection. - */ - void scheduleTimeout(ManagedConnection* connection); - - /* - * Schedule a callback on the wheel timer - */ - void scheduleTimeout(folly::HHWheelTimer::Callback* callback, - std::chrono::milliseconds timeout); - - /** - * Remove a connection from this ConnectionManager and, if - * applicable, cancel the pending timeout callback that the - * ConnectionManager has scheduled for the connection. - * - * @note This method does NOT destroy the connection. - */ - void removeConnection(ManagedConnection* connection); - - /* Begin gracefully shutting down connections in this ConnectionManager. - * Notify all connections of pending shutdown, and after idleGrace, - * begin closing idle connections. - */ - void initiateGracefulShutdown(std::chrono::milliseconds idleGrace); - - /** - * Destroy all connections Managed by this ConnectionManager, even - * the ones that are busy. - */ - void dropAllConnections(); - - size_t getNumConnections() const { return conns_.size(); } - - template - void iterateConns(F func) { - auto it = conns_.begin(); - while ( it != conns_.end()) { - func(&(*it)); - it++; - } - } - - private: - class CloseIdleConnsCallback : - public folly::EventBase::LoopCallback, - public folly::AsyncTimeout { - public: - explicit CloseIdleConnsCallback(ConnectionManager* manager) - : folly::AsyncTimeout(manager->eventBase_), - manager_(manager) {} - - void runLoopCallback() noexcept override { - VLOG(3) << "Draining more conns from loop callback"; - manager_->drainAllConnections(); - } - - void timeoutExpired() noexcept override { - VLOG(3) << "Idle grace expired"; - manager_->drainAllConnections(); - } - - private: - ConnectionManager* manager_; - }; - - enum class ShutdownAction : uint8_t { - /** - * Drain part 1: inform remote that you will soon reject new requests. - */ - DRAIN1 = 0, - /** - * Drain part 2: start rejecting new requests. - */ - DRAIN2 = 1, - }; - - ~ConnectionManager() {} - - ConnectionManager(const ConnectionManager&) = delete; - ConnectionManager& operator=(ConnectionManager&) = delete; - - /** - * Destroy all connections managed by this ConnectionManager that - * are currently idle, as determined by a call to each ManagedConnection's - * isBusy() method. - */ - void drainAllConnections(); - - /** All connections */ - folly::CountedIntrusiveList< - ManagedConnection,&ManagedConnection::listHook_> conns_; - - /** Connections that currently are registered for timeouts */ - folly::HHWheelTimer::UniquePtr connTimeouts_; - - /** Optional callback to notify of state changes */ - Callback* callback_; - - /** Event base in which we run */ - folly::EventBase* eventBase_; - - /** Iterator to the next connection to shed; used by drainAllConnections() */ - folly::CountedIntrusiveList< - ManagedConnection,&ManagedConnection::listHook_>::iterator idleIterator_; - CloseIdleConnsCallback idleLoopCallback_; - ShutdownAction action_{ShutdownAction::DRAIN1}; - std::chrono::milliseconds timeout_; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/ManagedConnection.cpp b/folly/experimental/wangle/ManagedConnection.cpp deleted file mode 100644 index 66db04f2..00000000 --- a/folly/experimental/wangle/ManagedConnection.cpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -namespace folly { namespace wangle { - -ManagedConnection::ManagedConnection() - : connectionManager_(nullptr) { -} - -ManagedConnection::~ManagedConnection() { - if (connectionManager_) { - connectionManager_->removeConnection(this); - } -} - -void -ManagedConnection::resetTimeout() { - if (connectionManager_) { - connectionManager_->scheduleTimeout(this); - } -} - -void -ManagedConnection::scheduleTimeout( - folly::HHWheelTimer::Callback* callback, - std::chrono::milliseconds timeout) { - if (connectionManager_) { - connectionManager_->scheduleTimeout(callback, timeout); - } -} - -////////////////////// Globals ///////////////////// - -std::ostream& -operator<<(std::ostream& os, const ManagedConnection& conn) { - conn.describe(os); - return os; -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/ManagedConnection.h b/folly/experimental/wangle/ManagedConnection.h deleted file mode 100644 index 50e7c057..00000000 --- a/folly/experimental/wangle/ManagedConnection.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace folly { namespace wangle { - -class ConnectionManager; - -/** - * Interface describing a connection that can be managed by a - * container such as an Acceptor. - */ -class ManagedConnection: - public folly::HHWheelTimer::Callback, - public folly::DelayedDestruction { - public: - - ManagedConnection(); - - // HHWheelTimer::Callback API (left for subclasses to implement). - virtual void timeoutExpired() noexcept = 0; - - /** - * Print a human-readable description of the connection. - * @param os Destination stream. - */ - virtual void describe(std::ostream& os) const = 0; - - /** - * Check whether the connection has any requests outstanding. - */ - virtual bool isBusy() const = 0; - - /** - * Notify the connection that a shutdown is pending. This method will be - * called at the beginning of graceful shutdown. - */ - virtual void notifyPendingShutdown() = 0; - - /** - * Instruct the connection that it should shutdown as soon as it is - * safe. This is called after notifyPendingShutdown(). - */ - virtual void closeWhenIdle() = 0; - - /** - * Forcibly drop a connection. - * - * If a request is in progress, this should cause the connection to be - * closed with a reset. - */ - virtual void dropConnection() = 0; - - /** - * Dump the state of the connection to the log - */ - virtual void dumpConnectionState(uint8_t loglevel) = 0; - - /** - * If the connection has a connection manager, reset the timeout - * countdown. - * @note If the connection manager doesn't have the connection scheduled - * for a timeout already, this method will schedule one. If the - * connection manager does have the connection connection scheduled - * for a timeout, this method will push back the timeout to N msec - * from now, where N is the connection manager's timer interval. - */ - virtual void resetTimeout(); - - // Schedule an arbitrary timeout on the HHWheelTimer - virtual void scheduleTimeout( - folly::HHWheelTimer::Callback* callback, - std::chrono::milliseconds timeout); - - ConnectionManager* getConnectionManager() { - return connectionManager_; - } - - protected: - virtual ~ManagedConnection(); - - private: - friend class ConnectionManager; - - void setConnectionManager(ConnectionManager* mgr) { - connectionManager_ = mgr; - } - - ConnectionManager* connectionManager_; - - folly::SafeIntrusiveListHook listHook_; -}; - -std::ostream& operator<<(std::ostream& os, const ManagedConnection& conn); - -}} // folly::wangle diff --git a/folly/experimental/wangle/acceptor/Acceptor.cpp b/folly/experimental/wangle/acceptor/Acceptor.cpp deleted file mode 100644 index c3b46b07..00000000 --- a/folly/experimental/wangle/acceptor/Acceptor.cpp +++ /dev/null @@ -1,437 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using folly::wangle::ConnectionManager; -using folly::wangle::ManagedConnection; -using std::chrono::microseconds; -using std::chrono::milliseconds; -using std::filebuf; -using std::ifstream; -using std::ios; -using std::shared_ptr; -using std::string; - -namespace folly { - -#ifndef NO_LIB_GFLAGS -DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before " - "closing idle conns"); -#else -const int32_t FLAGS_shutdown_idle_grace_ms = 5000; -#endif - -static const std::string empty_string; -std::atomic Acceptor::totalNumPendingSSLConns_{0}; - -/** - * Lightweight wrapper class to keep track of a newly - * accepted connection during SSL handshaking. - */ -class AcceptorHandshakeHelper : - public AsyncSSLSocket::HandshakeCB, - public ManagedConnection { - public: - AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket, - Acceptor* acceptor, - const SocketAddress& clientAddr, - std::chrono::steady_clock::time_point acceptTime) - : socket_(std::move(socket)), acceptor_(acceptor), - acceptTime_(acceptTime), clientAddr_(clientAddr) { - acceptor_->downstreamConnectionManager_->addConnection(this, true); - if(acceptor_->parseClientHello_) { - socket_->enableClientHelloParsing(); - } - socket_->sslAccept(this); - } - - virtual void timeoutExpired() noexcept { - VLOG(4) << "SSL handshake timeout expired"; - sslError_ = SSLErrorEnum::TIMEOUT; - dropConnection(); - } - virtual void describe(std::ostream& os) const { - os << "pending handshake on " << clientAddr_; - } - virtual bool isBusy() const { - return true; - } - virtual void notifyPendingShutdown() {} - virtual void closeWhenIdle() {} - - virtual void dropConnection() { - VLOG(10) << "Dropping in progress handshake for " << clientAddr_; - socket_->closeNow(); - } - virtual void dumpConnectionState(uint8_t loglevel) { - } - - private: - // AsyncSSLSocket::HandshakeCallback API - virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept { - - const unsigned char* nextProto = nullptr; - unsigned nextProtoLength = 0; - sock->getSelectedNextProtocol(&nextProto, &nextProtoLength); - if (VLOG_IS_ON(3)) { - if (nextProto) { - VLOG(3) << "Client selected next protocol " << - string((const char*)nextProto, nextProtoLength); - } else { - VLOG(3) << "Client did not select a next protocol"; - } - } - - // fill in SSL-related fields from TransportInfo - // the other fields like RTT are filled in the Acceptor - TransportInfo tinfo; - tinfo.ssl = true; - tinfo.acceptTime = acceptTime_; - tinfo.sslSetupTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - acceptTime_); - tinfo.sslSetupBytesRead = sock->getRawBytesReceived(); - tinfo.sslSetupBytesWritten = sock->getRawBytesWritten(); - tinfo.sslServerName = sock->getSSLServerName(); - tinfo.sslCipher = sock->getNegotiatedCipherName(); - tinfo.sslVersion = sock->getSSLVersion(); - tinfo.sslCertSize = sock->getSSLCertSize(); - tinfo.sslResume = SSLUtil::getResumeState(sock); - sock->getSSLClientCiphers(tinfo.sslClientCiphers); - sock->getSSLServerCiphers(tinfo.sslServerCiphers); - tinfo.sslClientComprMethods = sock->getSSLClientComprMethods(); - tinfo.sslClientExts = sock->getSSLClientExts(); - tinfo.sslNextProtocol.assign( - reinterpret_cast(nextProto), - nextProtoLength); - - acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR); - acceptor_->downstreamConnectionManager_->removeConnection(this); - acceptor_->sslConnectionReady(std::move(socket_), clientAddr_, - nextProto ? string((const char*)nextProto, nextProtoLength) : - empty_string, tinfo); - delete this; - } - - virtual void handshakeErr(AsyncSSLSocket* sock, - const AsyncSocketException& ex) noexcept { - auto elapsedTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - acceptTime_); - VLOG(3) << "SSL handshake error after " << elapsedTime.count() << - " ms; " << sock->getRawBytesReceived() << " bytes received & " << - sock->getRawBytesWritten() << " bytes sent: " << - ex.what(); - acceptor_->updateSSLStats(sock, elapsedTime, sslError_); - acceptor_->sslConnectionError(); - delete this; - } - - AsyncSSLSocket::UniquePtr socket_; - Acceptor* acceptor_; - std::chrono::steady_clock::time_point acceptTime_; - SocketAddress clientAddr_; - SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR}; -}; - -Acceptor::Acceptor(const ServerSocketConfig& accConfig) : - accConfig_(accConfig), - socketOptions_(accConfig.getSocketOptions()) { -} - -void -Acceptor::init(AsyncServerSocket* serverSocket, - EventBase* eventBase) { - CHECK(nullptr == this->base_); - - if (accConfig_.isSSL()) { - if (!sslCtxManager_) { - sslCtxManager_ = folly::make_unique( - eventBase, - "vip_" + getName(), - accConfig_.strictSSL, nullptr); - } - for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) { - sslCtxManager_->addSSLContextConfig( - sslCtxConfig, - accConfig_.sslCacheOptions, - &accConfig_.initialTicketSeeds, - accConfig_.bindAddress, - cacheProvider_); - parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled; - } - - CHECK(sslCtxManager_->getDefaultSSLCtx()); - } - - base_ = eventBase; - state_ = State::kRunning; - downstreamConnectionManager_ = ConnectionManager::makeUnique( - eventBase, accConfig_.connectionIdleTimeout, this); - - if (serverSocket) { - serverSocket->addAcceptCallback(this, eventBase); - // SO_KEEPALIVE is the only setting that is inherited by accepted - // connections so only apply this setting - for (const auto& option: socketOptions_) { - if (option.first.level == SOL_SOCKET && - option.first.optname == SO_KEEPALIVE && option.second == 1) { - serverSocket->setKeepAliveEnabled(true); - break; - } - } - } -} - -Acceptor::~Acceptor(void) { -} - -void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) { - sslCtxManager_->addSSLContextConfig(sslCtxConfig, - accConfig_.sslCacheOptions, - &accConfig_.initialTicketSeeds, - accConfig_.bindAddress, - cacheProvider_); -} - -void -Acceptor::drainAllConnections() { - if (downstreamConnectionManager_) { - downstreamConnectionManager_->initiateGracefulShutdown( - std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms)); - } -} - -void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from, - IConnectionCounter* counter) { - loadShedConfig_ = from; - connectionCounter_ = counter; -} - -bool Acceptor::canAccept(const SocketAddress& address) { - if (!connectionCounter_) { - return true; - } - - uint64_t maxConnections = connectionCounter_->getMaxConnections(); - if (maxConnections == 0) { - return true; - } - - uint64_t currentConnections = connectionCounter_->getNumConnections(); - if (currentConnections < maxConnections) { - return true; - } - - if (loadShedConfig_.isWhitelisted(address)) { - return true; - } - - // Take care of comparing connection count against max connections across - // all acceptors. Expensive since a lock must be taken to get the counter. - auto connectionCountForLoadShedding = getConnectionCountForLoadShedding(); - if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) { - return true; - } - - VLOG(4) << address.describe() << " not whitelisted"; - return false; -} - -void -Acceptor::connectionAccepted( - int fd, const SocketAddress& clientAddr) noexcept { - if (!canAccept(clientAddr)) { - close(fd); - return; - } - auto acceptTime = std::chrono::steady_clock::now(); - for (const auto& opt: socketOptions_) { - opt.first.apply(fd, opt.second); - } - - onDoneAcceptingConnection(fd, clientAddr, acceptTime); -} - -void Acceptor::onDoneAcceptingConnection( - int fd, - const SocketAddress& clientAddr, - std::chrono::steady_clock::time_point acceptTime) noexcept { - processEstablishedConnection(fd, clientAddr, acceptTime); -} - -void -Acceptor::processEstablishedConnection( - int fd, - const SocketAddress& clientAddr, - std::chrono::steady_clock::time_point acceptTime) noexcept { - if (accConfig_.isSSL()) { - CHECK(sslCtxManager_); - AsyncSSLSocket::UniquePtr sslSock( - makeNewAsyncSSLSocket( - sslCtxManager_->getDefaultSSLCtx(), base_, fd)); - ++numPendingSSLConns_; - ++totalNumPendingSSLConns_; - if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) { - VLOG(2) << "dropped SSL handshake on " << accConfig_.name << - " too many handshakes in progress"; - updateSSLStats(sslSock.get(), std::chrono::milliseconds(0), - SSLErrorEnum::DROPPED); - sslConnectionError(); - return; - } - new AcceptorHandshakeHelper( - std::move(sslSock), this, clientAddr, acceptTime); - } else { - TransportInfo tinfo; - tinfo.ssl = false; - tinfo.acceptTime = acceptTime; - AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd)); - connectionReady(std::move(sock), clientAddr, empty_string, tinfo); - } -} - -void -Acceptor::connectionReady( - AsyncSocket::UniquePtr sock, - const SocketAddress& clientAddr, - const string& nextProtocolName, - TransportInfo& tinfo) { - // Limit the number of reads from the socket per poll loop iteration, - // both to keep memory usage under control and to prevent one fast- - // writing client from starving other connections. - sock->setMaxReadsPerEvent(16); - tinfo.initWithSocket(sock.get()); - onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo); -} - -void -Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock, - const SocketAddress& clientAddr, - const string& nextProtocol, - TransportInfo& tinfo) { - CHECK(numPendingSSLConns_ > 0); - connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo); - --numPendingSSLConns_; - --totalNumPendingSSLConns_; - if (state_ == State::kDraining) { - checkDrained(); - } -} - -void -Acceptor::sslConnectionError() { - CHECK(numPendingSSLConns_ > 0); - --numPendingSSLConns_; - --totalNumPendingSSLConns_; - if (state_ == State::kDraining) { - checkDrained(); - } -} - -void -Acceptor::acceptError(const std::exception& ex) noexcept { - // An error occurred. - // The most likely error is out of FDs. AsyncServerSocket will back off - // briefly if we are out of FDs, then continue accepting later. - // Just log a message here. - LOG(ERROR) << "error accepting on acceptor socket: " << ex.what(); -} - -void -Acceptor::acceptStopped() noexcept { - VLOG(3) << "Acceptor " << this << " acceptStopped()"; - // Drain the open client connections - drainAllConnections(); - - // If we haven't yet finished draining, begin doing so by marking ourselves - // as in the draining state. We must be sure to hit checkDrained() here, as - // if we're completely idle, we can should consider ourself drained - // immediately (as there is no outstanding work to complete to cause us to - // re-evaluate this). - if (state_ != State::kDone) { - state_ = State::kDraining; - checkDrained(); - } -} - -void -Acceptor::onEmpty(const ConnectionManager& cm) { - VLOG(3) << "Acceptor=" << this << " onEmpty()"; - if (state_ == State::kDraining) { - checkDrained(); - } -} - -void -Acceptor::checkDrained() { - CHECK(state_ == State::kDraining); - if (forceShutdownInProgress_ || - (downstreamConnectionManager_->getNumConnections() != 0) || - (numPendingSSLConns_ != 0)) { - return; - } - - VLOG(2) << "All connections drained from Acceptor=" << this << " in thread " - << base_; - - downstreamConnectionManager_.reset(); - - state_ = State::kDone; - - onConnectionsDrained(); -} - -milliseconds -Acceptor::getConnTimeout() const { - return accConfig_.connectionIdleTimeout; -} - -void Acceptor::addConnection(ManagedConnection* conn) { - // Add the socket to the timeout manager so that it can be cleaned - // up after being left idle for a long time. - downstreamConnectionManager_->addConnection(conn, true); -} - -void -Acceptor::forceStop() { - base_->runInEventBaseThread([&] { dropAllConnections(); }); -} - -void -Acceptor::dropAllConnections() { - if (downstreamConnectionManager_) { - VLOG(3) << "Dropping all connections from Acceptor=" << this << - " in thread " << base_; - assert(base_->isInEventBaseThread()); - forceShutdownInProgress_ = true; - downstreamConnectionManager_->dropAllConnections(); - CHECK(downstreamConnectionManager_->getNumConnections() == 0); - downstreamConnectionManager_.reset(); - } - CHECK(numPendingSSLConns_ == 0); - - state_ = State::kDone; - onConnectionsDrained(); -} - -} // namespace diff --git a/folly/experimental/wangle/acceptor/Acceptor.h b/folly/experimental/wangle/acceptor/Acceptor.h deleted file mode 100644 index 404425e7..00000000 --- a/folly/experimental/wangle/acceptor/Acceptor.h +++ /dev/null @@ -1,346 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include "folly/experimental/wangle/acceptor/ServerSocketConfig.h" -#include "folly/experimental/wangle/acceptor/ConnectionCounter.h" -#include -#include "folly/experimental/wangle/acceptor/LoadShedConfiguration.h" -#include "folly/experimental/wangle/ssl/SSLCacheProvider.h" -#include "folly/experimental/wangle/acceptor/TransportInfo.h" - -#include -#include -#include -#include - -namespace folly { namespace wangle { -class ManagedConnection; -}} - -namespace folly { - -class SocketAddress; -class SSLContext; -class AsyncTransport; -class SSLContextManager; - -/** - * An abstract acceptor for TCP-based network services. - * - * There is one acceptor object per thread for each listening socket. When a - * new connection arrives on the listening socket, it is accepted by one of the - * acceptor objects. From that point on the connection will be processed by - * that acceptor's thread. - * - * The acceptor will call the abstract onNewConnection() method to create - * a new ManagedConnection object for each accepted socket. The acceptor - * also tracks all outstanding connections that it has accepted. - */ -class Acceptor : - public folly::AsyncServerSocket::AcceptCallback, - public folly::wangle::ConnectionManager::Callback { - public: - - enum class State : uint32_t { - kInit, // not yet started - kRunning, // processing requests normally - kDraining, // processing outstanding conns, but not accepting new ones - kDone, // no longer accepting, and all connections finished - }; - - explicit Acceptor(const ServerSocketConfig& accConfig); - virtual ~Acceptor(); - - /** - * Supply an SSL cache provider - * @note Call this before init() - */ - virtual void setSSLCacheProvider( - const std::shared_ptr& cacheProvider) { - cacheProvider_ = cacheProvider; - } - - /** - * Initialize the Acceptor to run in the specified EventBase - * thread, receiving connections from the specified AsyncServerSocket. - * - * This method will be called from the AsyncServerSocket's primary thread, - * not the specified EventBase thread. - */ - virtual void init(AsyncServerSocket* serverSocket, - EventBase* eventBase); - - /** - * Dynamically add a new SSLContextConfig - */ - void addSSLContextConfig(const SSLContextConfig& sslCtxConfig); - - SSLContextManager* getSSLContextManager() const { - return sslCtxManager_.get(); - } - - /** - * Return the number of outstanding connections in this service instance. - */ - uint32_t getNumConnections() const { - return downstreamConnectionManager_ ? - downstreamConnectionManager_->getNumConnections() : 0; - } - - /** - * Access the Acceptor's event base. - */ - virtual EventBase* getEventBase() const { return base_; } - - /** - * Access the Acceptor's downstream (client-side) ConnectionManager - */ - virtual folly::wangle::ConnectionManager* getConnectionManager() { - return downstreamConnectionManager_.get(); - } - - /** - * Invoked when a new ManagedConnection is created. - * - * This allows the Acceptor to track the outstanding connections, - * for tracking timeouts and for ensuring that all connections have been - * drained on shutdown. - */ - void addConnection(folly::wangle::ManagedConnection* connection); - - /** - * Get this acceptor's current state. - */ - State getState() const { - return state_; - } - - /** - * Get the current connection timeout. - */ - std::chrono::milliseconds getConnTimeout() const; - - /** - * Returns the name of this VIP. - * - * Will return an empty string if no name has been configured. - */ - const std::string& getName() const { - return accConfig_.name; - } - - /** - * Force the acceptor to drop all connections and stop processing. - * - * This function may be called from any thread. The acceptor will not - * necessarily stop before this function returns: the stop will be scheduled - * to run in the acceptor's thread. - */ - virtual void forceStop(); - - bool isSSL() const { return accConfig_.isSSL(); } - - const ServerSocketConfig& getConfig() const { return accConfig_; } - - static uint64_t getTotalNumPendingSSLConns() { - return totalNumPendingSSLConns_.load(); - } - - /** - * Called right when the TCP connection has been accepted, before processing - * the first HTTP bytes (HTTP) or the SSL handshake (HTTPS) - */ - virtual void onDoneAcceptingConnection( - int fd, - const SocketAddress& clientAddr, - std::chrono::steady_clock::time_point acceptTime - ) noexcept; - - /** - * Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS) - */ - void processEstablishedConnection( - int fd, - const SocketAddress& clientAddr, - std::chrono::steady_clock::time_point acceptTime - ) noexcept; - - /** - * Drains all open connections of their outstanding transactions. When - * a connection's transaction count reaches zero, the connection closes. - */ - void drainAllConnections(); - - protected: - friend class AcceptorHandshakeHelper; - - /** - * Our event loop. - * - * Probably needs to be used to pass to a ManagedConnection - * implementation. Also visible in case a subclass wishes to do additional - * things w/ the event loop (e.g. in attach()). - */ - EventBase* base_{nullptr}; - - virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; } - - /** - * Hook for subclasses to drop newly accepted connections prior - * to handshaking. - */ - virtual bool canAccept(const folly::SocketAddress&); - - /** - * Invoked when a new connection is created. This is where application starts - * processing a new downstream connection. - * - * NOTE: Application should add the new connection to - * downstreamConnectionManager so that it can be garbage collected after - * certain period of idleness. - * - * @param sock the socket connected to the client - * @param address the address of the client - * @param nextProtocolName the name of the L6 or L7 protocol to be - * spoken on the connection, if known (e.g., - * from TLS NPN during secure connection setup), - * or an empty string if unknown - */ - virtual void onNewConnection( - AsyncSocket::UniquePtr sock, - const folly::SocketAddress* address, - const std::string& nextProtocolName, - const TransportInfo& tinfo) = 0; - - virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) { - return AsyncSocket::UniquePtr(new AsyncSocket(base, fd)); - } - - virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket( - const std::shared_ptr& ctx, EventBase* base, int fd) { - return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd)); - } - - /** - * Hook for subclasses to record stats about SSL connection establishment. - */ - virtual void updateSSLStats( - const AsyncSSLSocket* sock, - std::chrono::milliseconds acceptLatency, - SSLErrorEnum error) noexcept {} - - /** - * Drop all connections. - * - * forceStop() schedules dropAllConnections() to be called in the acceptor's - * thread. - */ - void dropAllConnections(); - - protected: - - /** - * onConnectionsDrained() will be called once all connections have been - * drained while the acceptor is stopping. - * - * Subclasses can override this method to perform any subclass-specific - * cleanup. - */ - virtual void onConnectionsDrained() {} - - // AsyncServerSocket::AcceptCallback methods - void connectionAccepted(int fd, - const folly::SocketAddress& clientAddr) - noexcept; - void acceptError(const std::exception& ex) noexcept; - void acceptStopped() noexcept; - - // ConnectionManager::Callback methods - void onEmpty(const folly::wangle::ConnectionManager& cm); - void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {} - void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {} - - /** - * Process a connection that is to ready to receive L7 traffic. - * This method is called immediately upon accept for plaintext - * connections and upon completion of SSL handshaking or resumption - * for SSL connections. - */ - void connectionReady( - AsyncSocket::UniquePtr sock, - const folly::SocketAddress& clientAddr, - const std::string& nextProtocolName, - TransportInfo& tinfo); - - const LoadShedConfiguration& getLoadShedConfiguration() const { - return loadShedConfig_; - } - - protected: - const ServerSocketConfig accConfig_; - void setLoadShedConfig(const LoadShedConfiguration& from, - IConnectionCounter* counter); - - /** - * Socket options to apply to the client socket - */ - AsyncSocket::OptionMap socketOptions_; - - std::unique_ptr sslCtxManager_; - - /** - * Whether we want to enable client hello parsing in the handshake helper - * to get list of supported client ciphers. - */ - bool parseClientHello_{false}; - - folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_; - - private: - - // Forbidden copy constructor and assignment opererator - Acceptor(Acceptor const &) = delete; - Acceptor& operator=(Acceptor const &) = delete; - - /** - * Wrapper for connectionReady() that decrements the count of - * pending SSL connections. - */ - void sslConnectionReady(AsyncSocket::UniquePtr sock, - const folly::SocketAddress& clientAddr, - const std::string& nextProtocol, - TransportInfo& tinfo); - - /** - * Notification callback for SSL handshake failures. - */ - void sslConnectionError(); - - void checkDrained(); - - State state_{State::kInit}; - uint64_t numPendingSSLConns_{0}; - - static std::atomic totalNumPendingSSLConns_; - - bool forceShutdownInProgress_{false}; - LoadShedConfiguration loadShedConfig_; - IConnectionCounter* connectionCounter_{nullptr}; - std::shared_ptr cacheProvider_; -}; - -class AcceptorFactory { - public: - virtual std::shared_ptr newAcceptor() = 0; - virtual ~AcceptorFactory() = default; -}; - -} // namespace diff --git a/folly/experimental/wangle/acceptor/ConnectionCounter.h b/folly/experimental/wangle/acceptor/ConnectionCounter.h deleted file mode 100644 index bf891bb2..00000000 --- a/folly/experimental/wangle/acceptor/ConnectionCounter.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -namespace folly { - -class IConnectionCounter { - public: - virtual uint64_t getNumConnections() const = 0; - - /** - * Get the maximum number of non-whitelisted client-side connections - * across all Acceptors managed by this. A value - * of zero means "unlimited." - */ - virtual uint64_t getMaxConnections() const = 0; - - /** - * Increment the count of client-side connections. - */ - virtual void onConnectionAdded() = 0; - - /** - * Decrement the count of client-side connections. - */ - virtual void onConnectionRemoved() = 0; - virtual ~IConnectionCounter() {} -}; - -class SimpleConnectionCounter: public IConnectionCounter { - public: - uint64_t getNumConnections() const override { return numConnections_; } - uint64_t getMaxConnections() const override { return maxConnections_; } - void setMaxConnections(uint64_t maxConnections) { - maxConnections_ = maxConnections; - } - - void onConnectionAdded() override { numConnections_++; } - void onConnectionRemoved() override { numConnections_--; } - virtual ~SimpleConnectionCounter() {} - - protected: - uint64_t maxConnections_{0}; - uint64_t numConnections_{0}; -}; - -} diff --git a/folly/experimental/wangle/acceptor/DomainNameMisc.h b/folly/experimental/wangle/acceptor/DomainNameMisc.h deleted file mode 100644 index 41c4c741..00000000 --- a/folly/experimental/wangle/acceptor/DomainNameMisc.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include - -namespace folly { - -struct dn_char_traits : public std::char_traits { - static bool eq(char c1, char c2) { - return ::tolower(c1) == ::tolower(c2); - } - - static bool ne(char c1, char c2) { - return ::tolower(c1) != ::tolower(c2); - } - - static bool lt(char c1, char c2) { - return ::tolower(c1) < ::tolower(c2); - } - - static int compare(const char* s1, const char* s2, size_t n) { - while (n--) { - if(::tolower(*s1) < ::tolower(*s2) ) { - return -1; - } - if(::tolower(*s1) > ::tolower(*s2) ) { - return 1; - } - ++s1; - ++s2; - } - return 0; - } - - static const char* find(const char* s, size_t n, char a) { - char la = ::tolower(a); - while (n--) { - if(::tolower(*s) == la) { - return s; - } else { - ++s; - } - } - return nullptr; - } -}; - -// Case insensitive string -typedef std::basic_string DNString; - -struct DNStringHash : public std::hash { - size_t operator()(const DNString& s) const noexcept { - size_t h = static_cast(0xc70f6907UL); - const char* d = s.data(); - for (size_t i = 0; i < s.length(); ++i) { - char a = ::tolower(*d++); - h = std::_Hash_impl::hash(&a, sizeof(a), h); - } - return h; - } -}; - -} // namespace diff --git a/folly/experimental/wangle/acceptor/LoadShedConfiguration.cpp b/folly/experimental/wangle/acceptor/LoadShedConfiguration.cpp deleted file mode 100644 index e08e71b7..00000000 --- a/folly/experimental/wangle/acceptor/LoadShedConfiguration.cpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include - -using std::string; - -namespace folly { - -void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) { - auto addr = input.str(); - size_t separator = addr.find_first_of('/'); - if (separator == string::npos) { - whitelistAddrs_.insert(SocketAddress(addr, 0)); - } else { - unsigned prefixLen = folly::to(addr.substr(separator + 1)); - addr.erase(separator); - whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen)); - } -} - -bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const { - if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) { - return true; - } - for (auto& network : whitelistNetworks_) { - if (network.contains(address)) { - return true; - } - } - return false; -} - -} diff --git a/folly/experimental/wangle/acceptor/LoadShedConfiguration.h b/folly/experimental/wangle/acceptor/LoadShedConfiguration.h deleted file mode 100644 index 3c70b6ed..00000000 --- a/folly/experimental/wangle/acceptor/LoadShedConfiguration.h +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace folly { - -/** - * Class that holds an LoadShed configuration for a service - */ -class LoadShedConfiguration { - public: - - // Comparison function for SocketAddress that disregards the port - struct AddressOnlyCompare { - bool operator()( - const SocketAddress& addr1, - const SocketAddress& addr2) const { - return addr1.getIPAddress() < addr2.getIPAddress(); - } - }; - - typedef std::set AddressSet; - typedef std::set NetworkSet; - - LoadShedConfiguration() {} - - ~LoadShedConfiguration() {} - - void addWhitelistAddr(folly::StringPiece); - - /** - * Set/get the set of IPs that should be whitelisted through even when we're - * trying to shed load. - */ - void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; } - const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; } - - /** - * Set/get the set of networks that should be whitelisted through even - * when we're trying to shed load. - */ - void setWhitelistNetworks(const NetworkSet& networks) { - whitelistNetworks_ = networks; - } - const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; } - - /** - * Set/get the maximum number of downstream connections across all VIPs. - */ - void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; } - uint64_t getMaxConnections() const { return maxConnections_; } - - /** - * Set/get the maximum cpu usage. - */ - void setMaxMemUsage(double max) { - CHECK(max >= 0); - CHECK(max <= 1); - maxMemUsage_ = max; - } - double getMaxMemUsage() const { return maxMemUsage_; } - - /** - * Set/get the maximum memory usage. - */ - void setMaxCpuUsage(double max) { - CHECK(max >= 0); - CHECK(max <= 1); - maxCpuUsage_ = max; - } - double getMaxCpuUsage() const { return maxCpuUsage_; } - - void setLoadUpdatePeriod(std::chrono::milliseconds period) { - period_ = period; - } - std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; } - - bool isWhitelisted(const SocketAddress& addr) const; - - private: - - AddressSet whitelistAddrs_; - NetworkSet whitelistNetworks_; - uint64_t maxConnections_{0}; - double maxMemUsage_; - double maxCpuUsage_; - std::chrono::milliseconds period_; -}; - -} diff --git a/folly/experimental/wangle/acceptor/NetworkAddress.h b/folly/experimental/wangle/acceptor/NetworkAddress.h deleted file mode 100644 index 36980371..00000000 --- a/folly/experimental/wangle/acceptor/NetworkAddress.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include - -namespace folly { - -/** - * A simple wrapper around SocketAddress that represents - * a network in CIDR notation - */ -class NetworkAddress { -public: - /** - * Create a NetworkAddress for an addr/prefixLen - * @param addr IPv4 or IPv6 address of the network - * @param prefixLen Prefix length, in bits - */ - NetworkAddress(const folly::SocketAddress& addr, - unsigned prefixLen): - addr_(addr), prefixLen_(prefixLen) {} - - /** Get the network address */ - const folly::SocketAddress& getAddress() const { - return addr_; - } - - /** Get the prefix length in bits */ - unsigned getPrefixLength() const { return prefixLen_; } - - /** Check whether a given address lies within the network */ - bool contains(const folly::SocketAddress& addr) const { - return addr_.prefixMatch(addr, prefixLen_); - } - - /** Comparison operator to enable use in ordered collections */ - bool operator<(const NetworkAddress& other) const { - if (addr_ < other.addr_) { - return true; - } else if (other.addr_ < addr_) { - return false; - } else { - return (prefixLen_ < other.prefixLen_); - } - } - -private: - folly::SocketAddress addr_; - unsigned prefixLen_; -}; - -} // namespace diff --git a/folly/experimental/wangle/acceptor/ServerSocketConfig.h b/folly/experimental/wangle/acceptor/ServerSocketConfig.h deleted file mode 100644 index 14dc42cb..00000000 --- a/folly/experimental/wangle/acceptor/ServerSocketConfig.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace folly { - -/** - * Configuration for a single Acceptor. - * - * This configures not only accept behavior, but also some types of SSL - * behavior that may make sense to configure on a per-VIP basis (e.g. which - * cert(s) we use, etc). - */ -struct ServerSocketConfig { - ServerSocketConfig() { - // generate a single random current seed - uint8_t seed[32]; - folly::Random::secureRandom(seed, sizeof(seed)); - initialTicketSeeds.currentSeeds.push_back( - SSLUtil::hexlify(std::string((char *)seed, sizeof(seed)))); - } - - bool isSSL() const { return !(sslContextConfigs.empty()); } - - /** - * Set/get the socket options to apply on all downstream connections. - */ - void setSocketOptions( - const AsyncSocket::OptionMap& opts) { - socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily()); - } - AsyncSocket::OptionMap& - getSocketOptions() { - return socketOptions_; - } - const AsyncSocket::OptionMap& - getSocketOptions() const { - return socketOptions_; - } - - bool hasExternalPrivateKey() const { - for (const auto& cfg : sslContextConfigs) { - if (!cfg.isLocalPrivateKey) { - return true; - } - } - return false; - } - - /** - * The name of this acceptor; used for stats/reporting purposes. - */ - std::string name; - - /** - * The depth of the accept queue backlog. - */ - uint32_t acceptBacklog{1024}; - - /** - * The number of milliseconds a connection can be idle before we close it. - */ - std::chrono::milliseconds connectionIdleTimeout{600000}; - - /** - * The address to bind to. - */ - SocketAddress bindAddress; - - /** - * Options for controlling the SSL cache. - */ - SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200}; - - /** - * The initial TLS ticket seeds. - */ - TLSTicketKeySeeds initialTicketSeeds; - - /** - * The configs for all the SSL_CTX for use by this Acceptor. - */ - std::vector sslContextConfigs; - - /** - * Determines if the Acceptor does strict checking when loading the SSL - * contexts. - */ - bool strictSSL{true}; - - /** - * Maximum number of concurrent pending SSL handshakes - */ - uint32_t maxConcurrentSSLHandshakes{30720}; - - private: - AsyncSocket::OptionMap socketOptions_; -}; - -} // folly diff --git a/folly/experimental/wangle/acceptor/SocketOptions.cpp b/folly/experimental/wangle/acceptor/SocketOptions.cpp deleted file mode 100644 index c7c82b88..00000000 --- a/folly/experimental/wangle/acceptor/SocketOptions.cpp +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include - -namespace folly { - -AsyncSocket::OptionMap filterIPSocketOptions( - const AsyncSocket::OptionMap& allOptions, - const int addrFamily) { - AsyncSocket::OptionMap opts; - int exclude; - if (addrFamily == AF_INET) { - exclude = IPPROTO_IPV6; - } else if (addrFamily == AF_INET6) { - exclude = IPPROTO_IP; - } else { - LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6"; - return opts; - } - for (const auto& opt: allOptions) { - if (opt.first.level != exclude) { - opts[opt.first] = opt.second; - } - } - return opts; -} - -} diff --git a/folly/experimental/wangle/acceptor/SocketOptions.h b/folly/experimental/wangle/acceptor/SocketOptions.h deleted file mode 100644 index 37ba3711..00000000 --- a/folly/experimental/wangle/acceptor/SocketOptions.h +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include - -namespace folly { - -/** - * Returns a copy of the socket options excluding options with the given - * level. - */ -AsyncSocket::OptionMap filterIPSocketOptions( - const AsyncSocket::OptionMap& allOptions, - const int addrFamily); - -} diff --git a/folly/experimental/wangle/acceptor/TransportInfo.cpp b/folly/experimental/wangle/acceptor/TransportInfo.cpp deleted file mode 100644 index 02de719c..00000000 --- a/folly/experimental/wangle/acceptor/TransportInfo.cpp +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include -#include - -using std::chrono::microseconds; -using std::map; -using std::string; - -namespace folly { - -bool TransportInfo::initWithSocket(const AsyncSocket* sock) { -#if defined(__linux__) || defined(__FreeBSD__) - if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) { - tcpinfoErrno = errno; - return false; - } - rtt = microseconds(tcpinfo.tcpi_rtt); - validTcpinfo = true; -#else - tcpinfoErrno = EINVAL; - rtt = microseconds(-1); -#endif - return true; -} - -int64_t TransportInfo::readRTT(const AsyncSocket* sock) { -#if defined(__linux__) || defined(__FreeBSD__) - struct tcp_info tcpinfo; - if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) { - return -1; - } - return tcpinfo.tcpi_rtt; -#else - return -1; -#endif -} - -#if defined(__linux__) || defined(__FreeBSD__) -bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo, - const AsyncSocket* sock) { - socklen_t len = sizeof(struct tcp_info); - if (!sock) { - return false; - } - if (getsockopt(sock->getFd(), IPPROTO_TCP, - TCP_INFO, (void*) tcpinfo, &len) < 0) { - VLOG(4) << "Error calling getsockopt(): " << strerror(errno); - return false; - } - return true; -} -#endif - -} // folly diff --git a/folly/experimental/wangle/acceptor/TransportInfo.h b/folly/experimental/wangle/acceptor/TransportInfo.h deleted file mode 100644 index e108c466..00000000 --- a/folly/experimental/wangle/acceptor/TransportInfo.h +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include - -#include -#include -#include - -namespace folly { -class AsyncSocket; - -/** - * A structure that encapsulates byte counters related to the HTTP headers. - */ -struct HTTPHeaderSize { - /** - * The number of bytes used to represent the header after compression or - * before decompression. If header compression is not supported, the value - * is set to 0. - */ - size_t compressed{0}; - - /** - * The number of bytes used to represent the serialized header before - * compression or after decompression, in plain-text format. - */ - size_t uncompressed{0}; -}; - -struct TransportInfo { - /* - * timestamp of when the connection handshake was completed - */ - std::chrono::steady_clock::time_point acceptTime{}; - - /* - * connection RTT (Round-Trip Time) - */ - std::chrono::microseconds rtt{0}; - -#if defined(__linux__) || defined(__FreeBSD__) - /* - * TCP information as fetched from getsockopt(2) - */ - tcp_info tcpinfo { -#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17 - 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32 -#else - 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29 -#endif // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17 - }; -#endif // defined(__linux__) || defined(__FreeBSD__) - - /* - * time for setting the connection, from the moment in was accepted until it - * is established. - */ - std::chrono::milliseconds setupTime{0}; - - /* - * time for setting up the SSL connection or SSL handshake - */ - std::chrono::milliseconds sslSetupTime{0}; - - /* - * The name of the SSL ciphersuite used by the transaction's - * transport. Returns null if the transport is not SSL. - */ - const char* sslCipher{nullptr}; - - /* - * The SSL server name used by the transaction's - * transport. Returns null if the transport is not SSL. - */ - const char* sslServerName{nullptr}; - - /* - * list of ciphers sent by the client - */ - std::string sslClientCiphers{}; - - /* - * list of compression methods sent by the client - */ - std::string sslClientComprMethods{}; - - /* - * list of TLS extensions sent by the client - */ - std::string sslClientExts{}; - - /* - * hash of all the SSL parameters sent by the client - */ - std::string sslSignature{}; - - /* - * list of ciphers supported by the server - */ - std::string sslServerCiphers{}; - - /* - * guessed "(os) (browser)" based on SSL Signature - */ - std::string guessedUserAgent{}; - - /** - * The result of SSL NPN negotiation. - */ - std::string sslNextProtocol{}; - - /* - * total number of bytes sent over the connection - */ - int64_t totalBytes{0}; - - /** - * header bytes read - */ - HTTPHeaderSize ingressHeader; - - /* - * header bytes written - */ - HTTPHeaderSize egressHeader; - - /* - * Here is how the timeToXXXByte variables are planned out: - * 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_ - * 2. You can get the timing between two ByteEvents by calculating their - * differences. For example: - * timeToLastBodyByteAck - timeToFirstByte - * => Total time to deliver the body - * 3. The calculation in point (2) is typically done outside acceptor - * - * Future plan: - * We should log the timestamps (TimePoints) and allow - * the consumer to calculate the latency whatever it - * wants instead of calculating them in wangle, for the sake of flexibility. - * For example: - * 1. TimePoint reqStartTimestamp; - * 2. TimePoint firstHeaderByteSentTimestamp; - * 3. TimePoint firstBodyByteTimestamp; - * 3. TimePoint lastBodyByteTimestamp; - * 4. TimePoint lastBodyByteAckTimestamp; - */ - - /* - * time to first header byte written to the kernel send buffer - * NOTE: It is not 100% accurate since TAsyncSocket does not do - * do callback on partial write. - */ - int32_t timeToFirstHeaderByte{-1}; - - /* - * time to first body byte written to the kernel send buffer - */ - int32_t timeToFirstByte{-1}; - - /* - * time to last body byte written to the kernel send buffer - */ - int32_t timeToLastByte{-1}; - - /* - * time to TCP Ack received for the last written body byte - */ - int32_t timeToLastBodyByteAck{-1}; - - /* - * time it took the client to ACK the last byte, from the moment when the - * kernel sent the last byte to the client and until it received the ACK - * for that byte - */ - int32_t lastByteAckLatency{-1}; - - /* - * time spent inside wangle - */ - int32_t proxyLatency{-1}; - - /* - * time between connection accepted and client message headers completed - */ - int32_t clientLatency{-1}; - - /* - * latency for communication with the server - */ - int32_t serverLatency{-1}; - - /* - * time used to get a usable connection. - */ - int32_t connectLatency{-1}; - - /* - * body bytes written - */ - uint32_t egressBodySize{0}; - - /* - * value of errno in case of getsockopt() error - */ - int tcpinfoErrno{0}; - - /* - * bytes read & written during SSL Setup - */ - uint32_t sslSetupBytesWritten{0}; - uint32_t sslSetupBytesRead{0}; - - /** - * SSL error detail - */ - uint32_t sslError{0}; - - /** - * body bytes read - */ - uint32_t ingressBodySize{0}; - - /* - * The SSL version used by the transaction's transport, in - * OpenSSL's format: 4 bits for the major version, followed by 4 bits - * for the minor version. Returns zero for non-SSL. - */ - uint16_t sslVersion{0}; - - /* - * The SSL certificate size. - */ - uint16_t sslCertSize{0}; - - /** - * response status code - */ - uint16_t statusCode{0}; - - /* - * The SSL mode for the transaction's transport: new session, - * resumed session, or neither (non-SSL). - */ - SSLResumeEnum sslResume{SSLResumeEnum::NA}; - - /* - * true if the tcpinfo was successfully read from the kernel - */ - bool validTcpinfo{false}; - - /* - * true if the connection is SSL, false otherwise - */ - bool ssl{false}; - - /* - * get the RTT value in milliseconds - */ - std::chrono::milliseconds getRttMs() const { - return std::chrono::duration_cast(rtt); - } - - /* - * initialize the fields related with tcp_info - */ - bool initWithSocket(const AsyncSocket* sock); - - /* - * Get the kernel's estimate of round-trip time (RTT) to the transport's peer - * in microseconds. Returns -1 on error. - */ - static int64_t readRTT(const AsyncSocket* sock); - -#if defined(__linux__) || defined(__FreeBSD__) - /* - * perform the getsockopt(2) syscall to fetch TCP info for a given socket - */ - static bool readTcpInfo(struct tcp_info* tcpinfo, - const AsyncSocket* sock); -#endif -}; - -} // folly diff --git a/folly/experimental/wangle/bootstrap/BootstrapTest.cpp b/folly/experimental/wangle/bootstrap/BootstrapTest.cpp deleted file mode 100644 index 25bb75ea..00000000 --- a/folly/experimental/wangle/bootstrap/BootstrapTest.cpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "folly/experimental/wangle/bootstrap/ServerBootstrap.h" -#include "folly/experimental/wangle/bootstrap/ClientBootstrap.h" -#include "folly/experimental/wangle/channel/ChannelHandler.h" - -#include -#include - -using namespace folly::wangle; -using namespace folly; - -typedef ChannelPipeline> Pipeline; - -class TestServer : public ServerBootstrap { - Pipeline* newPipeline(std::shared_ptr) { - return nullptr; - } -}; - -class TestClient : public ClientBootstrap { - Pipeline* newPipeline(std::shared_ptr sock) { - CHECK(sock->good()); - - // We probably aren't connected immedately, check after a small delay - EventBaseManager::get()->getEventBase()->runAfterDelay([sock](){ - CHECK(sock->readable()); - }, 100); - return nullptr; - } -}; - -class TestPipelineFactory : public PipelineFactory { - public: - Pipeline* newPipeline(std::shared_ptr sock) { - pipelines++; - return new Pipeline(); - } - std::atomic pipelines{0}; -}; - -TEST(Bootstrap, Basic) { - TestServer server; - TestClient client; -} - -TEST(Bootstrap, ServerWithPipeline) { - TestServer server; - server.childPipeline(std::make_shared()); - server.bind(0); - server.stop(); -} - -TEST(Bootstrap, ClientServerTest) { - TestServer server; - auto factory = std::make_shared(); - server.childPipeline(factory); - server.bind(0); - auto base = EventBaseManager::get()->getEventBase(); - - SocketAddress address; - server.getSockets()[0]->getAddress(&address); - - TestClient client; - client.connect(address); - base->loop(); - server.stop(); - - CHECK(factory->pipelines == 1); -} - -TEST(Bootstrap, ClientConnectionManagerTest) { - // Create a single IO thread, and verify that - // client connections are pooled properly - - TestServer server; - auto factory = std::make_shared(); - server.childPipeline(factory); - server.group(std::make_shared(1)); - server.bind(0); - auto base = EventBaseManager::get()->getEventBase(); - - SocketAddress address; - server.getSockets()[0]->getAddress(&address); - - TestClient client; - client.connect(address); - - TestClient client2; - client2.connect(address); - - base->loop(); - server.stop(); - - CHECK(factory->pipelines == 2); -} - -TEST(Bootstrap, ServerAcceptGroupTest) { - // Verify that server is using the accept IO group - - TestServer server; - auto factory = std::make_shared(); - server.childPipeline(factory); - server.group(std::make_shared(1), nullptr); - server.bind(0); - - SocketAddress address; - server.getSockets()[0]->getAddress(&address); - - boost::barrier barrier(2); - auto thread = std::thread([&](){ - TestClient client; - client.connect(address); - EventBaseManager::get()->getEventBase()->loop(); - barrier.wait(); - }); - barrier.wait(); - server.stop(); - thread.join(); - - CHECK(factory->pipelines == 1); -} - -TEST(Bootstrap, ServerAcceptGroup2Test) { - // Verify that server is using the accept IO group - - // Check if reuse port is supported, if not, don't run this test - try { - EventBase base; - auto serverSocket = AsyncServerSocket::newSocket(&base); - serverSocket->bind(0); - serverSocket->listen(0); - serverSocket->startAccepting(); - serverSocket->setReusePortEnabled(true); - serverSocket->stopAccepting(); - } catch(...) { - LOG(INFO) << "Reuse port probably not supported"; - return; - } - - TestServer server; - auto factory = std::make_shared(); - server.childPipeline(factory); - server.group(std::make_shared(4), nullptr); - server.bind(0); - - SocketAddress address; - server.getSockets()[0]->getAddress(&address); - - TestClient client; - client.connect(address); - EventBaseManager::get()->getEventBase()->loop(); - - server.stop(); - - CHECK(factory->pipelines == 1); -} diff --git a/folly/experimental/wangle/bootstrap/ClientBootstrap.h b/folly/experimental/wangle/bootstrap/ClientBootstrap.h deleted file mode 100644 index dadbf5c5..00000000 --- a/folly/experimental/wangle/bootstrap/ClientBootstrap.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -namespace folly { - -/* - * A thin wrapper around ChannelPipeline and AsyncSocket to match - * ServerBootstrap. On connect() a new pipeline is created. - */ -template -class ClientBootstrap { - public: - ClientBootstrap() { - } - ClientBootstrap* bind(int port) { - port_ = port; - return this; - } - ClientBootstrap* connect(SocketAddress address) { - pipeline_.reset( - newPipeline( - AsyncSocket::newSocket(EventBaseManager::get()->getEventBase(), address) - )); - return this; - } - - virtual ~ClientBootstrap() {} - - protected: - std::unique_ptr pipeline_; - - int port_; - - virtual Pipeline* newPipeline(std::shared_ptr socket) = 0; -}; - -} // namespace diff --git a/folly/experimental/wangle/bootstrap/ServerBootstrap-inl.h b/folly/experimental/wangle/bootstrap/ServerBootstrap-inl.h deleted file mode 100644 index 8db45409..00000000 --- a/folly/experimental/wangle/bootstrap/ServerBootstrap-inl.h +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include - -namespace folly { - -template -class ServerAcceptor : public Acceptor { - typedef std::unique_ptr PipelinePtr; - - class ServerConnection : public wangle::ManagedConnection { - public: - explicit ServerConnection(PipelinePtr pipeline) - : pipeline_(std::move(pipeline)) {} - - ~ServerConnection() { - } - - void timeoutExpired() noexcept { - } - - void describe(std::ostream& os) const {} - bool isBusy() const { - return false; - } - void notifyPendingShutdown() {} - void closeWhenIdle() {} - void dropConnection() {} - void dumpConnectionState(uint8_t loglevel) {} - private: - PipelinePtr pipeline_; - }; - - public: - explicit ServerAcceptor( - std::shared_ptr> pipelineFactory) - : Acceptor(ServerSocketConfig()) - , pipelineFactory_(pipelineFactory) { - Acceptor::init(nullptr, &base_); - } - - /* See Acceptor::onNewConnection for details */ - void onNewConnection( - AsyncSocket::UniquePtr transport, const SocketAddress* address, - const std::string& nextProtocolName, const TransportInfo& tinfo) { - - std::unique_ptr - pipeline(pipelineFactory_->newPipeline( - std::shared_ptr( - transport.release(), - folly::DelayedDestruction::Destructor()))); - auto connection = new ServerConnection(std::move(pipeline)); - Acceptor::addConnection(connection); - } - - ~ServerAcceptor() { - Acceptor::dropAllConnections(); - } - - private: - EventBase base_; - - std::shared_ptr> pipelineFactory_; -}; - -template -class ServerAcceptorFactory : public AcceptorFactory { - public: - explicit ServerAcceptorFactory( - std::shared_ptr> factory) - : factory_(factory) {} - - std::shared_ptr newAcceptor() { - return std::make_shared>(factory_); - } - private: - std::shared_ptr> factory_; -}; - -class ServerWorkerFactory : public folly::wangle::ThreadFactory { - public: - explicit ServerWorkerFactory(std::shared_ptr acceptorFactory) - : internalFactory_( - std::make_shared("BootstrapWorker")) - , acceptorFactory_(acceptorFactory) - {} - virtual std::thread newThread(folly::Func&& func) override; - - void setInternalFactory( - std::shared_ptr internalFactory); - void setNamePrefix(folly::StringPiece prefix); - - template - void forEachWorker(F&& f); - - private: - std::shared_ptr internalFactory_; - folly::RWSpinLock workersLock_; - std::map> workers_; - int32_t nextWorkerId_{0}; - - std::shared_ptr acceptorFactory_; -}; - -template -void ServerWorkerFactory::forEachWorker(F&& f) { - folly::RWSpinLock::ReadHolder guard(workersLock_); - for (const auto& kv : workers_) { - f(kv.second.get()); - } -} - -} // namespace diff --git a/folly/experimental/wangle/bootstrap/ServerBootstrap.cpp b/folly/experimental/wangle/bootstrap/ServerBootstrap.cpp deleted file mode 100644 index 7a75452b..00000000 --- a/folly/experimental/wangle/bootstrap/ServerBootstrap.cpp +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include - -namespace folly { - -std::thread ServerWorkerFactory::newThread( - folly::Func&& func) { - auto id = nextWorkerId_++; - auto worker = acceptorFactory_->newAcceptor(); - { - folly::RWSpinLock::WriteHolder guard(workersLock_); - workers_.insert({id, worker}); - } - return internalFactory_->newThread([=](){ - EventBaseManager::get()->setEventBase(worker->getEventBase(), false); - func(); - EventBaseManager::get()->clearEventBase(); - - worker->drainAllConnections(); - { - folly::RWSpinLock::WriteHolder guard(workersLock_); - workers_.erase(id); - } - }); -} - -void ServerWorkerFactory::setInternalFactory( - std::shared_ptr internalFactory) { - CHECK(workers_.empty()); - internalFactory_ = internalFactory; -} - -void ServerWorkerFactory::setNamePrefix(folly::StringPiece prefix) { - CHECK(workers_.empty()); - internalFactory_->setNamePrefix(prefix); -} - -} // namespace diff --git a/folly/experimental/wangle/bootstrap/ServerBootstrap.h b/folly/experimental/wangle/bootstrap/ServerBootstrap.h deleted file mode 100644 index 85edb646..00000000 --- a/folly/experimental/wangle/bootstrap/ServerBootstrap.h +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace folly { - -/* - * ServerBootstrap is a parent class intended to set up a - * high-performance TCP accepting server. It will manage a pool of - * accepting threads, any number of accepting sockets, a pool of - * IO-worker threads, and connection pool for each IO thread for you. - * - * The output is given as a ChannelPipeline template: given a - * PipelineFactory, it will create a new pipeline for each connection, - * and your server can handle the incoming bytes. - * - * BACKWARDS COMPATIBLITY: for servers already taking a pool of - * Acceptor objects, an AcceptorFactory can be given directly instead - * of a pipeline factory. - */ -template -class ServerBootstrap { - public: - /* TODO(davejwatson) - * - * If there is any work to be done BEFORE handing the work to IO - * threads, this handler is where the pipeline to do it would be - * set. - * - * This could be used for things like logging, load balancing, or - * advanced load balancing on IO threads. Netty also provides this. - */ - ServerBootstrap* handler() { - return this; - } - - /* - * BACKWARDS COMPATIBILITY - an acceptor factory can be set. Your - * Acceptor is responsible for managing the connection pool. - * - * @param childHandler - acceptor factory to call for each IO thread - */ - ServerBootstrap* childHandler(std::shared_ptr childHandler) { - acceptorFactory_ = childHandler; - return this; - } - - /* - * Set a pipeline factory that will be called for each new connection - * - * @param factory pipeline factory to use for each new connection - */ - ServerBootstrap* childPipeline( - std::shared_ptr> factory) { - pipelineFactory_ = factory; - return this; - } - - /* - * Set the IO executor. If not set, a default one will be created - * with one thread per core. - * - * @param io_group - io executor to use for IO threads. - */ - ServerBootstrap* group( - std::shared_ptr io_group) { - return group(nullptr, io_group); - } - - /* - * Set the acceptor executor, and IO executor. - * - * If no acceptor executor is set, a single thread will be created for accepts - * If no IO executor is set, a default of one thread per core will be created - * - * @param group - acceptor executor to use for acceptor threads. - * @param io_group - io executor to use for IO threads. - */ - ServerBootstrap* group( - std::shared_ptr accept_group, - std::shared_ptr io_group) { - if (!accept_group) { - accept_group = std::make_shared( - 1, std::make_shared("Acceptor Thread")); - } - if (!io_group) { - io_group = std::make_shared( - 32, std::make_shared("IO Thread")); - } - auto factoryBase = io_group->getThreadFactory(); - CHECK(factoryBase); - auto factory = std::dynamic_pointer_cast( - factoryBase); - CHECK(factory); // Must be named thread factory - - CHECK(acceptorFactory_ || pipelineFactory_); - - if (acceptorFactory_) { - workerFactory_ = std::make_shared( - acceptorFactory_); - } else { - workerFactory_ = std::make_shared( - std::make_shared>(pipelineFactory_)); - } - workerFactory_->setInternalFactory(factory); - - acceptor_group_ = accept_group; - io_group_ = io_group; - - auto numThreads = io_group_->numThreads(); - io_group_->setNumThreads(0); - io_group_->setThreadFactory(workerFactory_); - io_group_->setNumThreads(numThreads); - - return this; - } - - /* - * Bind to a port and start listening. - * One of childPipeline or childHandler must be called before bind - * - * @param port Port to listen on - */ - void bind(int port) { - // TODO take existing socket - - if (!workerFactory_) { - group(nullptr); - } - - bool reusePort = false; - if (acceptor_group_->numThreads() > 1) { - reusePort = true; - } - - std::mutex sock_lock; - std::vector> new_sockets; - - auto startupFunc = [&](std::shared_ptr barrier){ - auto socket = folly::AsyncServerSocket::newSocket(); - sock_lock.lock(); - new_sockets.push_back(socket); - sock_lock.unlock(); - socket->setReusePortEnabled(reusePort); - socket->attachEventBase(EventBaseManager::get()->getEventBase()); - socket->bind(port); - // TODO Take ServerSocketConfig - socket->listen(1024); - socket->startAccepting(); - - if (port == 0) { - SocketAddress address; - socket->getAddress(&address); - port = address.getPort(); - } - - barrier->wait(); - }; - - auto bind0 = std::make_shared(2); - acceptor_group_->add(std::bind(startupFunc, bind0)); - bind0->wait(); - - auto barrier = std::make_shared(acceptor_group_->numThreads()); - for (int i = 1; i < acceptor_group_->numThreads(); i++) { - acceptor_group_->add(std::bind(startupFunc, barrier)); - } - barrier->wait(); - - // Startup all the threads - for(auto socket : new_sockets) { - workerFactory_->forEachWorker([this, socket](Acceptor* worker){ - socket->getEventBase()->runInEventBaseThread([this, worker, socket](){ - socket->addAcceptCallback(worker, worker->getEventBase()); - }); - }); - } - - for (auto& socket : new_sockets) { - sockets_.push_back(socket); - } - } - - /* - * Stop listening on all sockets. - */ - void stop() { - auto barrier = std::make_shared(sockets_.size() + 1); - for (auto socket : sockets_) { - socket->getEventBase()->runInEventBaseThread([barrier, socket]() { - socket->stopAccepting(); - socket->detachEventBase(); - barrier->wait(); - }); - } - barrier->wait(); - sockets_.clear(); - - acceptor_group_->join(); - io_group_->join(); - } - - /* - * Get the list of listening sockets - */ - std::vector>& - getSockets() { - return sockets_; - } - - private: - std::shared_ptr acceptor_group_; - std::shared_ptr io_group_; - - std::shared_ptr workerFactory_; - std::vector> sockets_; - - std::shared_ptr acceptorFactory_; - std::shared_ptr> pipelineFactory_; -}; - -} // namespace diff --git a/folly/experimental/wangle/channel/AsyncSocketHandler.h b/folly/experimental/wangle/channel/AsyncSocketHandler.h deleted file mode 100644 index 8d586d0f..00000000 --- a/folly/experimental/wangle/channel/AsyncSocketHandler.h +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace folly { namespace wangle { - -class AsyncSocketHandler - : public folly::wangle::BytesToBytesHandler, - public AsyncSocket::ReadCallback { - public: - explicit AsyncSocketHandler( - std::shared_ptr socket) - : socket_(std::move(socket)) {} - - AsyncSocketHandler(AsyncSocketHandler&&) = default; - - ~AsyncSocketHandler() { - if (socket_) { - detachReadCallback(); - } - } - - void attachReadCallback() { - socket_->setReadCB(socket_->good() ? this : nullptr); - } - - void detachReadCallback() { - if (socket_->getReadCallback() == this) { - socket_->setReadCB(nullptr); - } - } - - void attachEventBase(folly::EventBase* eventBase) { - if (eventBase && !socket_->getEventBase()) { - socket_->attachEventBase(eventBase); - } - } - - void detachEventBase() { - detachReadCallback(); - if (socket_->getEventBase()) { - socket_->detachEventBase(); - } - } - - void attachPipeline(Context* ctx) override { - CHECK(!ctx_); - ctx_ = ctx; - } - - folly::wangle::Future write( - Context* ctx, - std::unique_ptr buf) override { - if (UNLIKELY(!buf)) { - return folly::wangle::makeFuture(); - } - - if (!socket_->good()) { - VLOG(5) << "socket is closed in write()"; - return folly::wangle::makeFuture(AsyncSocketException( - AsyncSocketException::AsyncSocketExceptionType::NOT_OPEN, - "socket is closed in write()")); - } - - auto cb = new WriteCallback(); - auto future = cb->promise_.getFuture(); - socket_->writeChain(cb, std::move(buf), ctx->getWriteFlags()); - return future; - }; - - folly::wangle::Future close(Context* ctx) { - if (socket_) { - detachReadCallback(); - socket_->closeNow(); - } - return folly::wangle::makeFuture(); - } - - // Must override to avoid warnings about hidden overloaded virtual due to - // AsyncSocket::ReadCallback::readEOF() - void readEOF(Context* ctx) override { - ctx->fireReadEOF(); - } - - void getReadBuffer(void** bufReturn, size_t* lenReturn) override { - const auto readBufferSettings = ctx_->getReadBufferSettings(); - const auto ret = bufQueue_.preallocate( - readBufferSettings.first, - readBufferSettings.second); - *bufReturn = ret.first; - *lenReturn = ret.second; - } - - void readDataAvailable(size_t len) noexcept override { - bufQueue_.postallocate(len); - ctx_->fireRead(bufQueue_); - } - - void readEOF() noexcept override { - ctx_->fireReadEOF(); - } - - void readErr(const AsyncSocketException& ex) - noexcept override { - ctx_->fireReadException(make_exception_wrapper(ex)); - } - - private: - class WriteCallback : private AsyncSocket::WriteCallback { - void writeSuccess() noexcept override { - promise_.setValue(); - delete this; - } - - void writeErr(size_t bytesWritten, - const AsyncSocketException& ex) - noexcept override { - promise_.setException(ex); - delete this; - } - - private: - friend class AsyncSocketHandler; - folly::wangle::Promise promise_; - }; - - Context* ctx_{nullptr}; - folly::IOBufQueue bufQueue_; - std::shared_ptr socket_{nullptr}; -}; - -}} diff --git a/folly/experimental/wangle/channel/ChannelHandler.h b/folly/experimental/wangle/channel/ChannelHandler.h deleted file mode 100644 index e7fd3135..00000000 --- a/folly/experimental/wangle/channel/ChannelHandler.h +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace folly { namespace wangle { - -template -class ChannelHandler { - public: - typedef Rin rin; - typedef Rout rout; - typedef Win win; - typedef Wout wout; - typedef ChannelHandlerContext Context; - virtual ~ChannelHandler() {} - - virtual void read(Context* ctx, Rin msg) = 0; - virtual void readEOF(Context* ctx) { - ctx->fireReadEOF(); - } - virtual void readException(Context* ctx, exception_wrapper e) { - ctx->fireReadException(std::move(e)); - } - - virtual Future write(Context* ctx, Win msg) = 0; - virtual Future close(Context* ctx) { - return ctx->fireClose(); - } - - virtual void attachPipeline(Context* ctx) {} - virtual void attachTransport(Context* ctx) {} - - virtual void detachPipeline(Context* ctx) {} - virtual void detachTransport(Context* ctx) {} - - /* - // Other sorts of things we might want, all shamelessly stolen from Netty - // inbound - virtual void exceptionCaught( - ChannelHandlerContext* ctx, - exception_wrapper e) {} - virtual void channelRegistered(ChannelHandlerContext* ctx) {} - virtual void channelUnregistered(ChannelHandlerContext* ctx) {} - virtual void channelActive(ChannelHandlerContext* ctx) {} - virtual void channelInactive(ChannelHandlerContext* ctx) {} - virtual void channelReadComplete(ChannelHandlerContext* ctx) {} - virtual void userEventTriggered(ChannelHandlerContext* ctx, void* evt) {} - virtual void channelWritabilityChanged(ChannelHandlerContext* ctx) {} - - // outbound - virtual Future bind( - ChannelHandlerContext* ctx, - SocketAddress localAddress) {} - virtual Future connect( - ChannelHandlerContext* ctx, - SocketAddress remoteAddress, SocketAddress localAddress) {} - virtual Future disconnect(ChannelHandlerContext* ctx) {} - virtual Future deregister(ChannelHandlerContext* ctx) {} - virtual Future read(ChannelHandlerContext* ctx) {} - virtual void flush(ChannelHandlerContext* ctx) {} - */ -}; - -template -class ChannelHandlerAdapter : public ChannelHandler { - public: - typedef typename ChannelHandler::Context Context; - - void read(Context* ctx, R msg) override { - ctx->fireRead(std::forward(msg)); - } - - Future write(Context* ctx, W msg) override { - return ctx->fireWrite(std::forward(msg)); - } -}; - -typedef ChannelHandlerAdapter> -BytesToBytesHandler; - -template -class ChannelHandlerPtr : public ChannelHandler< - typename Handler::rin, - typename Handler::rout, - typename Handler::win, - typename Handler::wout> { - public: - typedef typename std::conditional< - Shared, - std::shared_ptr, - Handler*>::type - HandlerPtr; - - typedef typename Handler::Context Context; - - explicit ChannelHandlerPtr(HandlerPtr handler) - : handler_(std::move(handler)) {} - - void setHandler(HandlerPtr handler) { - if (handler == handler_) { - return; - } - if (handler_ && ctx_) { - handler_->detachPipeline(ctx_); - } - handler_ = std::move(handler); - if (handler_ && ctx_) { - handler_->attachPipeline(ctx_); - if (ctx_->getTransport()) { - handler_->attachTransport(ctx_); - } - } - } - - void attachPipeline(Context* ctx) override { - ctx_ = ctx; - if (handler_) { - handler_->attachPipeline(ctx_); - } - } - - void attachTransport(Context* ctx) override { - ctx_ = ctx; - if (handler_) { - handler_->attachTransport(ctx_); - } - } - - void detachPipeline(Context* ctx) override { - ctx_ = ctx; - if (handler_) { - handler_->detachPipeline(ctx_); - } - } - - void detachTransport(Context* ctx) override { - ctx_ = ctx; - if (handler_) { - handler_->detachTransport(ctx_); - } - } - - void read(Context* ctx, typename Handler::rin msg) override { - DCHECK(handler_); - handler_->read(ctx, std::forward(msg)); - } - - void readEOF(Context* ctx) override { - DCHECK(handler_); - handler_->readEOF(ctx); - } - - void readException(Context* ctx, exception_wrapper e) override { - DCHECK(handler_); - handler_->readException(ctx, std::move(e)); - } - - Future write(Context* ctx, typename Handler::win msg) override { - DCHECK(handler_); - return handler_->write(ctx, std::forward(msg)); - } - - Future close(Context* ctx) override { - DCHECK(handler_); - return handler_->close(ctx); - } - - private: - Context* ctx_; - HandlerPtr handler_; -}; - -}} diff --git a/folly/experimental/wangle/channel/ChannelHandlerContext.h b/folly/experimental/wangle/channel/ChannelHandlerContext.h deleted file mode 100644 index 59ea3ae4..00000000 --- a/folly/experimental/wangle/channel/ChannelHandlerContext.h +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace folly { namespace wangle { - -template -class ChannelHandlerContext { - public: - virtual ~ChannelHandlerContext() {} - - virtual void fireRead(In msg) = 0; - virtual void fireReadEOF() = 0; - virtual void fireReadException(exception_wrapper e) = 0; - - virtual Future fireWrite(Out msg) = 0; - virtual Future fireClose() = 0; - - virtual std::shared_ptr getTransport() = 0; - - virtual void setWriteFlags(WriteFlags flags) = 0; - virtual WriteFlags getWriteFlags() = 0; - - virtual void setReadBufferSettings( - uint64_t minAvailable, - uint64_t allocationSize) = 0; - virtual std::pair getReadBufferSettings() = 0; - - /* TODO - template - virtual void addHandlerBefore(H&&) {} - template - virtual void addHandlerAfter(H&&) {} - template - virtual void replaceHandler(H&&) {} - virtual void removeHandler() {} - */ -}; - -class PipelineContext { - public: - virtual ~PipelineContext() {} - - virtual void attachTransport() = 0; - virtual void detachTransport() = 0; - - void link(PipelineContext* other) { - setNextIn(other); - other->setNextOut(this); - } - - protected: - virtual void setNextIn(PipelineContext* ctx) = 0; - virtual void setNextOut(PipelineContext* ctx) = 0; -}; - -template -class InboundChannelHandlerContext { - public: - virtual ~InboundChannelHandlerContext() {} - virtual void read(In msg) = 0; - virtual void readEOF() = 0; - virtual void readException(exception_wrapper e) = 0; -}; - -template -class OutboundChannelHandlerContext { - public: - virtual ~OutboundChannelHandlerContext() {} - virtual Future write(Out msg) = 0; - virtual Future close() = 0; -}; - -template -class ContextImpl : public ChannelHandlerContext, - public InboundChannelHandlerContext, - public OutboundChannelHandlerContext, - public PipelineContext { - public: - typedef typename H::rin Rin; - typedef typename H::rout Rout; - typedef typename H::win Win; - typedef typename H::wout Wout; - - template - explicit ContextImpl(P* pipeline, HandlerArg&& handlerArg) - : pipeline_(pipeline), - handler_(std::forward(handlerArg)) { - handler_.attachPipeline(this); - } - - ~ContextImpl() { - handler_.detachPipeline(this); - } - - H* getHandler() { - return &handler_; - } - - // PipelineContext overrides - void setNextIn(PipelineContext* ctx) override { - auto nextIn = dynamic_cast*>(ctx); - if (nextIn) { - nextIn_ = nextIn; - } else { - throw std::invalid_argument("wrong type in setNextIn"); - } - } - - void setNextOut(PipelineContext* ctx) override { - auto nextOut = dynamic_cast*>(ctx); - if (nextOut) { - nextOut_ = nextOut; - } else { - throw std::invalid_argument("wrong type in setNextOut"); - } - } - - void attachTransport() override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - handler_.attachTransport(this); - } - - void detachTransport() override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - handler_.detachTransport(this); - } - - // ChannelHandlerContext overrides - void fireRead(Rout msg) override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - if (nextIn_) { - nextIn_->read(std::forward(msg)); - } else { - LOG(WARNING) << "read reached end of pipeline"; - } - } - - void fireReadEOF() override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - if (nextIn_) { - nextIn_->readEOF(); - } else { - LOG(WARNING) << "readEOF reached end of pipeline"; - } - } - - void fireReadException(exception_wrapper e) override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - if (nextIn_) { - nextIn_->readException(std::move(e)); - } else { - LOG(WARNING) << "readException reached end of pipeline"; - } - } - - Future fireWrite(Wout msg) override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - if (nextOut_) { - return nextOut_->write(std::forward(msg)); - } else { - LOG(WARNING) << "write reached end of pipeline"; - return makeFuture(); - } - } - - Future fireClose() override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - if (nextOut_) { - return nextOut_->close(); - } else { - LOG(WARNING) << "close reached end of pipeline"; - return makeFuture(); - } - } - - std::shared_ptr getTransport() override { - return pipeline_->getTransport(); - } - - void setWriteFlags(WriteFlags flags) override { - pipeline_->setWriteFlags(flags); - } - - WriteFlags getWriteFlags() override { - return pipeline_->getWriteFlags(); - } - - void setReadBufferSettings( - uint64_t minAvailable, - uint64_t allocationSize) override { - pipeline_->setReadBufferSettings(minAvailable, allocationSize); - } - - std::pair getReadBufferSettings() override { - return pipeline_->getReadBufferSettings(); - } - - // InboundChannelHandlerContext overrides - void read(Rin msg) override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - handler_.read(this, std::forward(msg)); - } - - void readEOF() override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - handler_.readEOF(this); - } - - void readException(exception_wrapper e) override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - handler_.readException(this, std::move(e)); - } - - // OutboundChannelHandlerContext overrides - Future write(Win msg) override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - return handler_.write(this, std::forward(msg)); - } - - Future close() override { - typename P::DestructorGuard dg(static_cast(pipeline_)); - return handler_.close(this); - } - - private: - P* pipeline_; - H handler_; - InboundChannelHandlerContext* nextIn_{nullptr}; - OutboundChannelHandlerContext* nextOut_{nullptr}; -}; - -}} diff --git a/folly/experimental/wangle/channel/ChannelPipeline.h b/folly/experimental/wangle/channel/ChannelPipeline.h deleted file mode 100644 index 386caec3..00000000 --- a/folly/experimental/wangle/channel/ChannelPipeline.h +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace folly { namespace wangle { - -/* - * R is the inbound type, i.e. inbound calls start with pipeline.read(R) - * W is the outbound type, i.e. outbound calls start with pipeline.write(W) - */ -template -class ChannelPipeline; - -template -class ChannelPipeline : public DelayedDestruction { - public: - ChannelPipeline() {} - ~ChannelPipeline() {} - - std::shared_ptr getTransport() { - return transport_; - } - - void setWriteFlags(WriteFlags flags) { - writeFlags_ = flags; - } - - WriteFlags getWriteFlags() { - return writeFlags_; - } - - void setReadBufferSettings(uint64_t minAvailable, uint64_t allocationSize) { - readBufferSettings_ = std::make_pair(minAvailable, allocationSize); - } - - std::pair getReadBufferSettings() { - return readBufferSettings_; - } - - void read(R msg) { - front_->read(std::forward(msg)); - } - - void readEOF() { - front_->readEOF(); - } - - void readException(exception_wrapper e) { - front_->readException(std::move(e)); - } - - Future write(W msg) { - return back_->write(std::forward(msg)); - } - - Future close() { - return back_->close(); - } - - template - ChannelPipeline& addBack(H&& handler) { - ctxs_.push_back(folly::make_unique>( - this, std::forward(handler))); - return *this; - } - - template - ChannelPipeline& addFront(H&& handler) { - ctxs_.insert( - ctxs_.begin(), - folly::make_unique>( - this, - std::forward(handler))); - return *this; - } - - template - H* getHandler(int i) { - auto ctx = dynamic_cast*>(ctxs_[i].get()); - CHECK(ctx); - return ctx->getHandler(); - } - - void finalize() { - finalizeHelper(); - InboundChannelHandlerContext* front; - front_ = dynamic_cast*>( - ctxs_.front().get()); - if (!front_) { - throw std::invalid_argument("wrong type for first handler"); - } - } - - protected: - explicit ChannelPipeline(bool shouldFinalize) { - CHECK(!shouldFinalize); - } - - void finalizeHelper() { - if (ctxs_.empty()) { - return; - } - - for (int i = 0; i < ctxs_.size() - 1; i++) { - ctxs_[i]->link(ctxs_[i+1].get()); - } - - back_ = dynamic_cast*>(ctxs_.back().get()); - if (!back_) { - throw std::invalid_argument("wrong type for last handler"); - } - } - - PipelineContext* getLocalFront() { - return ctxs_.empty() ? nullptr : ctxs_.front().get(); - } - - static const bool is_end{true}; - - std::shared_ptr transport_; - WriteFlags writeFlags_{WriteFlags::NONE}; - std::pair readBufferSettings_{2048, 2048}; - - void attachPipeline() {} - - void attachTransport( - std::shared_ptr transport) { - transport_ = std::move(transport); - } - - void detachTransport() { - transport_ = nullptr; - } - - OutboundChannelHandlerContext* back_{nullptr}; - - private: - InboundChannelHandlerContext* front_{nullptr}; - std::vector> ctxs_; -}; - -template -class ChannelPipeline - : public ChannelPipeline { - protected: - template - ChannelPipeline( - bool shouldFinalize, - HandlerArg&& handlerArg, - HandlersArgs&&... handlersArgs) - : ChannelPipeline( - false, - std::forward(handlersArgs)...), - ctx_(this, std::forward(handlerArg)) { - if (shouldFinalize) { - finalize(); - } - } - - public: - template - explicit ChannelPipeline(HandlersArgs&&... handlersArgs) - : ChannelPipeline(true, std::forward(handlersArgs)...) {} - - ~ChannelPipeline() {} - - void destroy() override { } - - void read(R msg) { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - front_->read(std::forward(msg)); - } - - void readEOF() { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - front_->readEOF(); - } - - void readException(exception_wrapper e) { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - front_->readException(std::move(e)); - } - - Future write(W msg) { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - return back_->write(std::forward(msg)); - } - - Future close() { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - return back_->close(); - } - - void attachTransport( - std::shared_ptr transport) { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - CHECK((!ChannelPipeline::transport_)); - ChannelPipeline::attachTransport(std::move(transport)); - forEachCtx([&](PipelineContext* ctx){ - ctx->attachTransport(); - }); - } - - void detachTransport() { - typename ChannelPipeline::DestructorGuard dg( - static_cast(this)); - ChannelPipeline::detachTransport(); - forEachCtx([&](PipelineContext* ctx){ - ctx->detachTransport(); - }); - } - - std::shared_ptr getTransport() { - return ChannelPipeline::transport_; - } - - template - ChannelPipeline& addBack(H&& handler) { - ChannelPipeline::addBack(std::move(handler)); - return *this; - } - - template - ChannelPipeline& addFront(H&& handler) { - ctxs_.insert( - ctxs_.begin(), - folly::make_unique>( - this, - std::move(handler))); - return *this; - } - - template - H* getHandler(size_t i) { - if (i > ctxs_.size()) { - return ChannelPipeline::template getHandler( - i - (ctxs_.size() + 1)); - } else { - auto pctx = (i == ctxs_.size()) ? &ctx_ : ctxs_[i].get(); - auto ctx = dynamic_cast*>(pctx); - return ctx->getHandler(); - } - } - - void finalize() { - finalizeHelper(); - auto ctx = ctxs_.empty() ? &ctx_ : ctxs_.front().get(); - front_ = dynamic_cast*>(ctx); - if (!front_) { - throw std::invalid_argument("wrong type for first handler"); - } - } - - protected: - void finalizeHelper() { - ChannelPipeline::finalizeHelper(); - back_ = ChannelPipeline::back_; - if (!back_) { - auto is_end = ChannelPipeline::is_end; - CHECK(is_end); - back_ = dynamic_cast*>(&ctx_); - if (!back_) { - throw std::invalid_argument("wrong type for last handler"); - } - } - - if (!ctxs_.empty()) { - for (int i = 0; i < ctxs_.size() - 1; i++) { - ctxs_[i]->link(ctxs_[i+1].get()); - } - ctxs_.back()->link(&ctx_); - } - - auto nextFront = ChannelPipeline::getLocalFront(); - if (nextFront) { - ctx_.link(nextFront); - } - } - - PipelineContext* getLocalFront() { - return ctxs_.empty() ? &ctx_ : ctxs_.front().get(); - } - - static const bool is_end{false}; - InboundChannelHandlerContext* front_{nullptr}; - OutboundChannelHandlerContext* back_{nullptr}; - - private: - template - void forEachCtx(const F& func) { - for (auto& ctx : ctxs_) { - func(ctx.get()); - } - func(&ctx_); - } - - ContextImpl ctx_; - std::vector> ctxs_; -}; - -}} - -namespace folly { - -class AsyncSocket; - -template -class PipelineFactory { - public: - virtual Pipeline* newPipeline(std::shared_ptr) = 0; - virtual ~PipelineFactory() {} -}; - -} diff --git a/folly/experimental/wangle/channel/OutputBufferingHandler.h b/folly/experimental/wangle/channel/OutputBufferingHandler.h deleted file mode 100644 index 04d12d00..00000000 --- a/folly/experimental/wangle/channel/OutputBufferingHandler.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace folly { namespace wangle { - -/* - * OutputBufferingHandler buffers writes in order to minimize syscalls. The - * transport will be written to once per event loop instead of on every write. - */ -class OutputBufferingHandler : public BytesToBytesHandler, - protected EventBase::LoopCallback { - public: - Future write(Context* ctx, std::unique_ptr buf) override { - CHECK(buf); - if (!queueSends_) { - return ctx->fireWrite(std::move(buf)); - } else { - ctx_ = ctx; - // Delay sends to optimize for fewer syscalls - if (!sends_) { - DCHECK(!isLoopCallbackScheduled()); - // Buffer all the sends, and call writev once per event loop. - sends_ = std::move(buf); - ctx->getTransport()->getEventBase()->runInLoop(this); - } else { - DCHECK(isLoopCallbackScheduled()); - sends_->prependChain(std::move(buf)); - } - Promise p; - auto f = p.getFuture(); - promises_.push_back(std::move(p)); - return f; - } - } - - void runLoopCallback() noexcept override { - MoveWrapper>> promises(std::move(promises_)); - ctx_->fireWrite(std::move(sends_)).then([promises](Try&& t) mutable { - try { - t.throwIfFailed(); - for (auto& p : *promises) { - p.setValue(); - } - } catch (...) { - for (auto& p : *promises) { - p.setException(std::current_exception()); - } - } - }); - } - - std::vector> promises_; - std::unique_ptr sends_{nullptr}; - bool queueSends_{true}; - Context* ctx_; -}; - -}} diff --git a/folly/experimental/wangle/channel/test/ChannelPipelineTest.cpp b/folly/experimental/wangle/channel/test/ChannelPipelineTest.cpp deleted file mode 100644 index cae1d063..00000000 --- a/folly/experimental/wangle/channel/test/ChannelPipelineTest.cpp +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -using namespace folly; -using namespace folly::wangle; -using namespace testing; - -typedef StrictMock> IntHandler; -typedef ChannelHandlerPtr IntHandlerPtr; - -ACTION(FireRead) { - arg0->fireRead(arg1); -} - -ACTION(FireReadEOF) { - arg0->fireReadEOF(); -} - -ACTION(FireReadException) { - arg0->fireReadException(arg1); -} - -ACTION(FireWrite) { - arg0->fireWrite(arg1); -} - -ACTION(FireClose) { - arg0->fireClose(); -} - -// Test move only types, among other things -TEST(ChannelTest, RealHandlersCompile) { - EventBase eb; - auto socket = AsyncSocket::newSocket(&eb); - // static - { - ChannelPipeline, - AsyncSocketHandler, - OutputBufferingHandler> - pipeline{AsyncSocketHandler(socket), OutputBufferingHandler()}; - EXPECT_TRUE(pipeline.getHandler(0)); - EXPECT_TRUE(pipeline.getHandler(1)); - } - // dynamic - { - ChannelPipeline> pipeline; - pipeline - .addBack(AsyncSocketHandler(socket)) - .addBack(OutputBufferingHandler()) - .finalize(); - EXPECT_TRUE(pipeline.getHandler(0)); - EXPECT_TRUE(pipeline.getHandler(1)); - } -} - -// Test that handlers correctly fire the next handler when directed -TEST(ChannelTest, FireActions) { - IntHandler handler1; - IntHandler handler2; - - EXPECT_CALL(handler1, attachPipeline(_)); - EXPECT_CALL(handler2, attachPipeline(_)); - - ChannelPipeline - pipeline(&handler1, &handler2); - - EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); - EXPECT_CALL(handler2, read_(_, _)).Times(1); - pipeline.read(1); - - EXPECT_CALL(handler1, readEOF(_)).WillOnce(FireReadEOF()); - EXPECT_CALL(handler2, readEOF(_)).Times(1); - pipeline.readEOF(); - - EXPECT_CALL(handler1, readException(_, _)).WillOnce(FireReadException()); - EXPECT_CALL(handler2, readException(_, _)).Times(1); - pipeline.readException(make_exception_wrapper("blah")); - - EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite()); - EXPECT_CALL(handler1, write_(_, _)).Times(1); - EXPECT_NO_THROW(pipeline.write(1).value()); - - EXPECT_CALL(handler2, close_(_)).WillOnce(FireClose()); - EXPECT_CALL(handler1, close_(_)).Times(1); - EXPECT_NO_THROW(pipeline.close().value()); - - EXPECT_CALL(handler1, detachPipeline(_)); - EXPECT_CALL(handler2, detachPipeline(_)); -} - -// Test that nothing bad happens when actions reach the end of the pipeline -// (a warning will be logged, however) -TEST(ChannelTest, ReachEndOfPipeline) { - IntHandler handler; - EXPECT_CALL(handler, attachPipeline(_)); - ChannelPipeline - pipeline(&handler); - - EXPECT_CALL(handler, read_(_, _)).WillOnce(FireRead()); - pipeline.read(1); - - EXPECT_CALL(handler, readEOF(_)).WillOnce(FireReadEOF()); - pipeline.readEOF(); - - EXPECT_CALL(handler, readException(_, _)).WillOnce(FireReadException()); - pipeline.readException(make_exception_wrapper("blah")); - - EXPECT_CALL(handler, write_(_, _)).WillOnce(FireWrite()); - EXPECT_NO_THROW(pipeline.write(1).value()); - - EXPECT_CALL(handler, close_(_)).WillOnce(FireClose()); - EXPECT_NO_THROW(pipeline.close().value()); - - EXPECT_CALL(handler, detachPipeline(_)); -} - -// Test having the last read handler turn around and write -TEST(ChannelTest, TurnAround) { - IntHandler handler1; - IntHandler handler2; - - EXPECT_CALL(handler1, attachPipeline(_)); - EXPECT_CALL(handler2, attachPipeline(_)); - - ChannelPipeline - pipeline(&handler1, &handler2); - - EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); - EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireWrite()); - EXPECT_CALL(handler1, write_(_, _)).Times(1); - pipeline.read(1); - - EXPECT_CALL(handler1, detachPipeline(_)); - EXPECT_CALL(handler2, detachPipeline(_)); -} - -TEST(ChannelTest, DynamicFireActions) { - IntHandler handler1, handler2, handler3; - EXPECT_CALL(handler2, attachPipeline(_)); - ChannelPipeline - pipeline(&handler2); - - EXPECT_CALL(handler1, attachPipeline(_)); - EXPECT_CALL(handler3, attachPipeline(_)); - - pipeline - .addFront(IntHandlerPtr(&handler1)) - .addBack(IntHandlerPtr(&handler3)) - .finalize(); - - EXPECT_TRUE(pipeline.getHandler(0)); - EXPECT_TRUE(pipeline.getHandler(1)); - EXPECT_TRUE(pipeline.getHandler(2)); - - EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); - EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireRead()); - EXPECT_CALL(handler3, read_(_, _)).Times(1); - pipeline.read(1); - - EXPECT_CALL(handler3, write_(_, _)).WillOnce(FireWrite()); - EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite()); - EXPECT_CALL(handler1, write_(_, _)).Times(1); - EXPECT_NO_THROW(pipeline.write(1).value()); - - EXPECT_CALL(handler1, detachPipeline(_)); - EXPECT_CALL(handler2, detachPipeline(_)); - EXPECT_CALL(handler3, detachPipeline(_)); -} - -template -class ConcreteChannelHandler : public ChannelHandler { - typedef typename ChannelHandler::Context Context; - public: - void read(Context* ctx, Rin msg) {} - Future write(Context* ctx, Win msg) { return makeFuture(); } -}; - -typedef ChannelHandlerAdapter StringHandler; -typedef ConcreteChannelHandler IntToStringHandler; -typedef ConcreteChannelHandler StringToIntHandler; - -TEST(ChannelPipeline, DynamicConstruction) { - { - ChannelPipeline pipeline; - EXPECT_THROW( - pipeline - .addBack(ChannelHandlerAdapter{}) - .finalize(), std::invalid_argument); - } - { - ChannelPipeline pipeline; - EXPECT_THROW( - pipeline - .addFront(ChannelHandlerAdapter{}) - .finalize(), - std::invalid_argument); - } - { - ChannelPipeline - pipeline{StringHandler(), StringHandler()}; - - // Exercise both addFront and addBack. Final pipeline is - // StI <-> ItS <-> StS <-> StS <-> StI <-> ItS - EXPECT_NO_THROW( - pipeline - .addFront(IntToStringHandler{}) - .addFront(StringToIntHandler{}) - .addBack(StringToIntHandler{}) - .addBack(IntToStringHandler{}) - .finalize()); - } -} - -TEST(ChannelPipeline, AttachTransport) { - IntHandler handler; - EXPECT_CALL(handler, attachPipeline(_)); - ChannelPipeline - pipeline(&handler); - - EventBase eb; - auto socket = AsyncSocket::newSocket(&eb); - - EXPECT_CALL(handler, attachTransport(_)); - pipeline.attachTransport(socket); - - EXPECT_CALL(handler, detachTransport(_)); - pipeline.detachTransport(); - - EXPECT_CALL(handler, detachPipeline(_)); -} diff --git a/folly/experimental/wangle/channel/test/MockChannelHandler.h b/folly/experimental/wangle/channel/test/MockChannelHandler.h deleted file mode 100644 index ddf511cb..00000000 --- a/folly/experimental/wangle/channel/test/MockChannelHandler.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace folly { namespace wangle { - -template -class MockChannelHandler : public ChannelHandler { - public: - typedef typename ChannelHandler::Context Context; - - MockChannelHandler() = default; - MockChannelHandler(MockChannelHandler&&) = default; - - MOCK_METHOD2_T(read_, void(Context*, Rin&)); - MOCK_METHOD1_T(readEOF, void(Context*)); - MOCK_METHOD2_T(readException, void(Context*, exception_wrapper)); - - MOCK_METHOD2_T(write_, void(Context*, Win&)); - MOCK_METHOD1_T(close_, void(Context*)); - - MOCK_METHOD1_T(attachPipeline, void(Context*)); - MOCK_METHOD1_T(attachTransport, void(Context*)); - MOCK_METHOD1_T(detachPipeline, void(Context*)); - MOCK_METHOD1_T(detachTransport, void(Context*)); - - void read(Context* ctx, Rin msg) { - read_(ctx, msg); - } - - Future write(Context* ctx, Win msg) override { - return makeFutureTry([&](){ - write_(ctx, msg); - }); - } - - Future close(Context* ctx) override { - return makeFutureTry([&](){ - close_(ctx); - }); - } -}; - -template -using MockChannelHandlerAdapter = MockChannelHandler; - -}} diff --git a/folly/experimental/wangle/channel/test/OutputBufferingHandlerTest.cpp b/folly/experimental/wangle/channel/test/OutputBufferingHandlerTest.cpp deleted file mode 100644 index 600a6a85..00000000 --- a/folly/experimental/wangle/channel/test/OutputBufferingHandlerTest.cpp +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace folly; -using namespace folly::wangle; -using namespace testing; - -typedef StrictMock>> -MockHandler; - -MATCHER_P(IOBufContains, str, "") { return arg->moveToFbString() == str; } - -TEST(OutputBufferingHandlerTest, Basic) { - MockHandler mockHandler; - EXPECT_CALL(mockHandler, attachPipeline(_)); - ChannelPipeline, - ChannelHandlerPtr, - OutputBufferingHandler> - pipeline(&mockHandler, OutputBufferingHandler{}); - - EventBase eb; - auto socket = AsyncSocket::newSocket(&eb); - EXPECT_CALL(mockHandler, attachTransport(_)); - pipeline.attachTransport(socket); - - // Buffering should prevent writes until the EB loops, and the writes should - // be batched into one write call. - auto f1 = pipeline.write(IOBuf::copyBuffer("hello")); - auto f2 = pipeline.write(IOBuf::copyBuffer("world")); - EXPECT_FALSE(f1.isReady()); - EXPECT_FALSE(f2.isReady()); - EXPECT_CALL(mockHandler, write_(_, IOBufContains("helloworld"))); - eb.loopOnce(); - EXPECT_TRUE(f1.isReady()); - EXPECT_TRUE(f2.isReady()); - EXPECT_CALL(mockHandler, detachPipeline(_)); -} diff --git a/folly/experimental/wangle/concurrent/BlockingQueue.h b/folly/experimental/wangle/concurrent/BlockingQueue.h deleted file mode 100644 index 08a1f703..00000000 --- a/folly/experimental/wangle/concurrent/BlockingQueue.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace folly { namespace wangle { - -template -class BlockingQueue { - public: - virtual ~BlockingQueue() {} - virtual void add(T item) = 0; - virtual void addWithPriority(T item, uint32_t priority) { - LOG_FIRST_N(WARNING, 1) << - "add(item, priority) called on a non-priority queue"; - add(std::move(item)); - } - virtual uint32_t getNumPriorities() { - LOG_FIRST_N(WARNING, 1) << - "getNumPriorities() called on a non-priority queue"; - return 1; - } - virtual T take() = 0; - virtual size_t size() = 0; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.cpp b/folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.cpp deleted file mode 100644 index 9caf6bee..00000000 --- a/folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.cpp +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace folly { namespace wangle { - -const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 18; -const size_t CPUThreadPoolExecutor::kDefaultNumPriorities = 2; - -CPUThreadPoolExecutor::CPUThreadPoolExecutor( - size_t numThreads, - std::unique_ptr> taskQueue, - std::shared_ptr threadFactory) - : ThreadPoolExecutor(numThreads, std::move(threadFactory)), - taskQueue_(std::move(taskQueue)) { - addThreads(numThreads); - CHECK(threadList_.get().size() == numThreads); -} - -CPUThreadPoolExecutor::CPUThreadPoolExecutor( - size_t numThreads, - std::shared_ptr threadFactory) - : CPUThreadPoolExecutor( - numThreads, - folly::make_unique>( - CPUThreadPoolExecutor::kDefaultMaxQueueSize), - std::move(threadFactory)) {} - -CPUThreadPoolExecutor::CPUThreadPoolExecutor(size_t numThreads) - : CPUThreadPoolExecutor( - numThreads, - std::make_shared("CPUThreadPool")) {} - -CPUThreadPoolExecutor::CPUThreadPoolExecutor( - size_t numThreads, - uint32_t numPriorities, - std::shared_ptr threadFactory) - : CPUThreadPoolExecutor( - numThreads, - folly::make_unique>( - numPriorities, - CPUThreadPoolExecutor::kDefaultMaxQueueSize), - std::move(threadFactory)) {} - -CPUThreadPoolExecutor::~CPUThreadPoolExecutor() { - stop(); - CHECK(threadsToStop_ == 0); -} - -void CPUThreadPoolExecutor::add(Func func) { - add(std::move(func), std::chrono::milliseconds(0)); -} - -void CPUThreadPoolExecutor::add( - Func func, - std::chrono::milliseconds expiration, - Func expireCallback) { - // TODO handle enqueue failure, here and in other add() callsites - taskQueue_->add( - CPUTask(std::move(func), expiration, std::move(expireCallback))); -} - -void CPUThreadPoolExecutor::add(Func func, uint32_t priority) { - add(std::move(func), priority, std::chrono::milliseconds(0)); -} - -void CPUThreadPoolExecutor::add( - Func func, - uint32_t priority, - std::chrono::milliseconds expiration, - Func expireCallback) { - CHECK(priority < getNumPriorities()); - taskQueue_->addWithPriority( - CPUTask(std::move(func), expiration, std::move(expireCallback)), - priority); -} - -uint32_t CPUThreadPoolExecutor::getNumPriorities() const { - return taskQueue_->getNumPriorities(); -} - -BlockingQueue* -CPUThreadPoolExecutor::getTaskQueue() { - return taskQueue_.get(); -} - -void CPUThreadPoolExecutor::threadRun(std::shared_ptr thread) { - thread->startupBaton.post(); - while (1) { - auto task = taskQueue_->take(); - if (UNLIKELY(task.poison)) { - CHECK(threadsToStop_-- > 0); - stoppedThreads_.add(thread); - return; - } else { - runTask(thread, std::move(task)); - } - - if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) { - if (--threadsToStop_ >= 0) { - stoppedThreads_.add(thread); - return; - } else { - threadsToStop_++; - } - } - } -} - -void CPUThreadPoolExecutor::stopThreads(size_t n) { - CHECK(stoppedThreads_.size() == 0); - threadsToStop_ = n; - for (size_t i = 0; i < n; i++) { - taskQueue_->add(CPUTask()); - } -} - -uint64_t CPUThreadPoolExecutor::getPendingTaskCount() { - return taskQueue_->size(); -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.h b/folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.h deleted file mode 100644 index b7e88685..00000000 --- a/folly/experimental/wangle/concurrent/CPUThreadPoolExecutor.h +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace folly { namespace wangle { - -class CPUThreadPoolExecutor : public ThreadPoolExecutor { - public: - struct CPUTask; - - explicit CPUThreadPoolExecutor( - size_t numThreads, - std::unique_ptr> taskQueue, - std::shared_ptr threadFactory = - std::make_shared("CPUThreadPool")); - - explicit CPUThreadPoolExecutor(size_t numThreads); - - explicit CPUThreadPoolExecutor( - size_t numThreads, - std::shared_ptr threadFactory); - - explicit CPUThreadPoolExecutor( - size_t numThreads, - uint32_t numPriorities, - std::shared_ptr threadFactory = - std::make_shared("CPUThreadPool")); - - ~CPUThreadPoolExecutor(); - - void add(Func func) override; - void add( - Func func, - std::chrono::milliseconds expiration, - Func expireCallback = nullptr) override; - - void add(Func func, uint32_t priority); - void add( - Func func, - uint32_t priority, - std::chrono::milliseconds expiration, - Func expireCallback = nullptr); - - uint32_t getNumPriorities() const; - - struct CPUTask : public ThreadPoolExecutor::Task { - // Must be noexcept move constructible so it can be used in MPMCQueue - explicit CPUTask( - Func&& f, - std::chrono::milliseconds expiration, - Func&& expireCallback) - : Task(std::move(f), expiration, std::move(expireCallback)), - poison(false) {} - CPUTask() - : Task(nullptr, std::chrono::milliseconds(0), nullptr), - poison(true) {} - CPUTask(CPUTask&& o) noexcept : Task(std::move(o)), poison(o.poison) {} - CPUTask(const CPUTask&) = default; - CPUTask& operator=(const CPUTask&) = default; - bool poison; - }; - - static const size_t kDefaultMaxQueueSize; - static const size_t kDefaultNumPriorities; - - protected: - BlockingQueue* getTaskQueue(); - - private: - void threadRun(ThreadPtr thread) override; - void stopThreads(size_t n) override; - uint64_t getPendingTaskCount() override; - - std::unique_ptr> taskQueue_; - std::atomic threadsToStop_{0}; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/Codel.cpp b/folly/experimental/wangle/concurrent/Codel.cpp deleted file mode 100644 index 527058ce..00000000 --- a/folly/experimental/wangle/concurrent/Codel.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#ifndef NO_LIB_GFLAGS - #include - DEFINE_int32(codel_interval, 100, - "Codel default interval time in ms"); - DEFINE_int32(codel_target_delay, 5, - "Target codel queueing delay in ms"); -#endif - -namespace folly { namespace wangle { - -#ifdef NO_LIB_GFLAGS - int32_t FLAGS_codel_interval = 100; - int32_t FLAGS_codel_target_delay = 5; -#endif - -Codel::Codel() - : codelMinDelay_(0), - codelIntervalTime_(std::chrono::steady_clock::now()), - codelResetDelay_(true), - overloaded_(false) {} - -bool Codel::overloaded(std::chrono::microseconds delay) { - bool ret = false; - auto now = std::chrono::steady_clock::now(); - - // Avoid another thread updating the value at the same time we are using it - // to calculate the overloaded state - auto minDelay = codelMinDelay_; - - if (now > codelIntervalTime_ && - (!codelResetDelay_.load(std::memory_order_acquire) - && !codelResetDelay_.exchange(true))) { - codelIntervalTime_ = now + std::chrono::milliseconds(FLAGS_codel_interval); - - if (minDelay > std::chrono::milliseconds(FLAGS_codel_target_delay)) { - overloaded_ = true; - } else { - overloaded_ = false; - } - } - // Care must be taken that only a single thread resets codelMinDelay_, - // and that it happens after the interval reset above - if (codelResetDelay_.load(std::memory_order_acquire) && - codelResetDelay_.exchange(false)) { - codelMinDelay_ = delay; - // More than one request must come in during an interval before codel - // starts dropping requests - return false; - } else if(delay < codelMinDelay_) { - codelMinDelay_ = delay; - } - - if (overloaded_ && - delay > std::chrono::milliseconds(FLAGS_codel_target_delay * 2)) { - ret = true; - } - - return ret; - -} - -int Codel::getLoad() { - return std::min(100, (int)codelMinDelay_.count() / - (2 * FLAGS_codel_target_delay)); -} - -int Codel::getMinDelay() { - return (int) codelMinDelay_.count(); -} - -}} //namespace diff --git a/folly/experimental/wangle/concurrent/Codel.h b/folly/experimental/wangle/concurrent/Codel.h deleted file mode 100644 index 16f0205b..00000000 --- a/folly/experimental/wangle/concurrent/Codel.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace folly { namespace wangle { - -/* Codel algorithm implementation: - * http://en.wikipedia.org/wiki/CoDel - * - * Algorithm modified slightly: Instead of changing the interval time - * based on the average min delay, instead we use an alternate timeout - * for each task if the min delay during the interval period is too - * high. - * - * This was found to have better latency metrics than changing the - * window size, since we can communicate with the sender via thrift - * instead of only via the tcp window size congestion control, as in TCP. - */ -class Codel { - - public: - Codel(); - - // Given a delay, returns wether the codel algorithm would - // reject a queued request with this delay. - // - // Internally, it also keeps track of the interval - bool overloaded(std::chrono::microseconds delay); - - // Get the queue load, as seen by the codel algorithm - // Gives a rough guess at how bad the queue delay is. - // - // Return: 0 = no delay, 100 = At the queueing limit - int getLoad(); - - int getMinDelay(); - - private: - std::chrono::microseconds codelMinDelay_; - std::chrono::time_point codelIntervalTime_; - - // flag to make overloaded() thread-safe, since we only want - // to reset the delay once per time period - std::atomic codelResetDelay_; - - bool overloaded_; -}; - -}} // Namespace diff --git a/folly/experimental/wangle/concurrent/FutureExecutor.h b/folly/experimental/wangle/concurrent/FutureExecutor.h deleted file mode 100644 index 8aeedff8..00000000 --- a/folly/experimental/wangle/concurrent/FutureExecutor.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include - -namespace folly { namespace wangle { - -template -class FutureExecutor : public ExecutorImpl { - public: - template - explicit FutureExecutor(Args&&... args) - : ExecutorImpl(std::forward(args)...) {} - - /* - * Given a function func that returns a Future, adds that function to the - * contained Executor and returns a Future which will be fulfilled with - * func's result once it has been executed. - * - * For example: auto f = futureExecutor.addFuture([](){ - * return doAsyncWorkAndReturnAFuture(); - * }); - */ - template - typename std::enable_if::type>::value, - typename std::result_of::type>::type - addFuture(F func) { - typedef typename std::result_of::type::value_type T; - Promise promise; - auto future = promise.getFuture(); - auto movePromise = folly::makeMoveWrapper(std::move(promise)); - auto moveFunc = folly::makeMoveWrapper(std::move(func)); - ExecutorImpl::add([movePromise, moveFunc] () mutable { - (*moveFunc)().then([movePromise] (Try&& t) mutable { - movePromise->fulfilTry(std::move(t)); - }); - }); - return future; - } - - /* - * Similar to addFuture above, but takes a func that returns some non-Future - * type T. - * - * For example: auto f = futureExecutor.addFuture([]() { - * return 42; - * }); - */ - template - typename std::enable_if::type>::value, - Future::type>>::type - addFuture(F func) { - typedef typename std::result_of::type T; - Promise promise; - auto future = promise.getFuture(); - auto movePromise = folly::makeMoveWrapper(std::move(promise)); - auto moveFunc = folly::makeMoveWrapper(std::move(func)); - ExecutorImpl::add([movePromise, moveFunc] () mutable { - movePromise->fulfil(std::move(*moveFunc)); - }); - return future; - } -}; - -}} diff --git a/folly/experimental/wangle/concurrent/GlobalExecutor.cpp b/folly/experimental/wangle/concurrent/GlobalExecutor.cpp deleted file mode 100644 index b0efd4f2..00000000 --- a/folly/experimental/wangle/concurrent/GlobalExecutor.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -using namespace folly; -using namespace folly::wangle; - -namespace { - -Singleton globalIOThreadPoolSingleton( - "GlobalIOThreadPool", - [](){ - return new IOThreadPoolExecutor( - sysconf(_SC_NPROCESSORS_ONLN), - std::make_shared("GlobalIOThreadPool")); - }); - -} - -namespace folly { namespace wangle { - -IOExecutor* getIOExecutor() { - auto singleton = IOExecutor::getSingleton(); - auto executor = singleton->load(); - while (!executor) { - IOExecutor* nullIOExecutor = nullptr; - singleton->compare_exchange_strong( - nullIOExecutor, - Singleton::get("GlobalIOThreadPool")); - executor = singleton->load(); - } - return executor; -} - -void setIOExecutor(IOExecutor* executor) { - IOExecutor::getSingleton()->store(executor); -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/GlobalExecutor.h b/folly/experimental/wangle/concurrent/GlobalExecutor.h deleted file mode 100644 index cac76be8..00000000 --- a/folly/experimental/wangle/concurrent/GlobalExecutor.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace folly { namespace wangle { - -class IOExecutor; - -// Retrieve the global IOExecutor. If there is none, a default -// IOThreadPoolExecutor will be constructed and returned. -IOExecutor* getIOExecutor(); - -// Set an IOExecutor to be the global IOExecutor which will be returned by -// subsequent calls to getIOExecutor(). IOExecutors will uninstall themselves -// as global when they are destructed. -void setIOExecutor(IOExecutor* executor); - -}} diff --git a/folly/experimental/wangle/concurrent/IOExecutor.cpp b/folly/experimental/wangle/concurrent/IOExecutor.cpp deleted file mode 100644 index d3985c99..00000000 --- a/folly/experimental/wangle/concurrent/IOExecutor.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include -#include - -using folly::Singleton; -using folly::wangle::IOExecutor; - -namespace { - -Singleton> globalIOExecutorSingleton( - "GlobalIOExecutor", - [](){ - return new std::atomic(nullptr); - }); - -} - -namespace folly { namespace wangle { - -IOExecutor::~IOExecutor() { - auto thisCopy = this; - try { - getSingleton()->compare_exchange_strong(thisCopy, nullptr); - } catch (const std::runtime_error& e) { - // The global IOExecutor singleton was already destructed so doesn't need to - // be restored. Ignore. - } -} - -std::atomic* IOExecutor::getSingleton() { - return Singleton>::get("GlobalIOExecutor"); -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/IOExecutor.h b/folly/experimental/wangle/concurrent/IOExecutor.h deleted file mode 100644 index 14eb6643..00000000 --- a/folly/experimental/wangle/concurrent/IOExecutor.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace folly { -class EventBase; -} - -namespace folly { namespace wangle { - -// An IOExecutor is an executor that operates on at least one EventBase. One of -// these EventBases should be accessible via getEventBase(). The event base -// returned by a call to getEventBase() is implementation dependent. -// -// Note that IOExecutors don't necessarily loop on the base themselves - for -// instance, EventBase itself is an IOExecutor but doesn't drive itself. -// -// Implementations of IOExecutor are eligible to become the global IO executor, -// returned on every call to getIOExecutor(), via setIOExecutor(). -// These functions are declared in GlobalExecutor.h -// -// If getIOExecutor is called and none has been set, a default global -// IOThreadPoolExecutor will be created and returned. -class IOExecutor : public virtual Executor { - public: - virtual ~IOExecutor(); - virtual EventBase* getEventBase() = 0; - - private: - static std::atomic* getSingleton(); - friend IOExecutor* getIOExecutor(); - friend void setIOExecutor(IOExecutor* executor); -}; - -}} diff --git a/folly/experimental/wangle/concurrent/IOThreadPoolExecutor.cpp b/folly/experimental/wangle/concurrent/IOThreadPoolExecutor.cpp deleted file mode 100644 index f1fa3904..00000000 --- a/folly/experimental/wangle/concurrent/IOThreadPoolExecutor.cpp +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include - -#include - -namespace folly { namespace wangle { - -using folly::detail::MemoryIdler; - -/* Class that will free jemalloc caches and madvise the stack away - * if the event loop is unused for some period of time - */ -class MemoryIdlerTimeout - : public AsyncTimeout , public EventBase::LoopCallback { - public: - explicit MemoryIdlerTimeout(EventBase* b) : AsyncTimeout(b), base_(b) {} - - virtual void timeoutExpired() noexcept { - idled = true; - } - - virtual void runLoopCallback() noexcept { - if (idled) { - MemoryIdler::flushLocalMallocCaches(); - MemoryIdler::unmapUnusedStack(MemoryIdler::kDefaultStackToRetain); - - idled = false; - } else { - std::chrono::steady_clock::duration idleTimeout = - MemoryIdler::defaultIdleTimeout.load( - std::memory_order_acquire); - - idleTimeout = MemoryIdler::getVariationTimeout(idleTimeout); - - scheduleTimeout(std::chrono::duration_cast( - idleTimeout).count()); - } - - // reschedule this callback for the next event loop. - base_->runBeforeLoop(this); - } - private: - EventBase* base_; - bool idled{false}; -} ; - -IOThreadPoolExecutor::IOThreadPoolExecutor( - size_t numThreads, - std::shared_ptr threadFactory) - : ThreadPoolExecutor(numThreads, std::move(threadFactory)), - nextThread_(0) { - addThreads(numThreads); - CHECK(threadList_.get().size() == numThreads); -} - -IOThreadPoolExecutor::~IOThreadPoolExecutor() { - stop(); -} - -void IOThreadPoolExecutor::add(Func func) { - add(std::move(func), std::chrono::milliseconds(0)); -} - -void IOThreadPoolExecutor::add( - Func func, - std::chrono::milliseconds expiration, - Func expireCallback) { - RWSpinLock::ReadHolder{&threadListLock_}; - if (threadList_.get().empty()) { - throw std::runtime_error("No threads available"); - } - auto ioThread = pickThread(); - - auto moveTask = folly::makeMoveWrapper( - Task(std::move(func), expiration, std::move(expireCallback))); - auto wrappedFunc = [ioThread, moveTask] () mutable { - runTask(ioThread, std::move(*moveTask)); - ioThread->pendingTasks--; - }; - - ioThread->pendingTasks++; - if (!ioThread->eventBase->runInEventBaseThread(std::move(wrappedFunc))) { - ioThread->pendingTasks--; - throw std::runtime_error("Unable to run func in event base thread"); - } -} - -std::shared_ptr -IOThreadPoolExecutor::pickThread() { - if (*thisThread_) { - return *thisThread_; - } - auto thread = threadList_.get()[nextThread_++ % threadList_.get().size()]; - return std::static_pointer_cast(thread); -} - -EventBase* IOThreadPoolExecutor::getEventBase() { - return pickThread()->eventBase; -} - -std::shared_ptr -IOThreadPoolExecutor::makeThread() { - return std::make_shared(this); -} - -void IOThreadPoolExecutor::threadRun(ThreadPtr thread) { - const auto ioThread = std::static_pointer_cast(thread); - ioThread->eventBase = - folly::EventBaseManager::get()->getEventBase(); - thisThread_.reset(new std::shared_ptr(ioThread)); - - auto idler = new MemoryIdlerTimeout(ioThread->eventBase); - ioThread->eventBase->runBeforeLoop(idler); - - thread->startupBaton.post(); - while (ioThread->shouldRun) { - ioThread->eventBase->loopForever(); - } - if (isJoin_) { - while (ioThread->pendingTasks > 0) { - ioThread->eventBase->loopOnce(); - } - } - stoppedThreads_.add(ioThread); -} - -// threadListLock_ is writelocked -void IOThreadPoolExecutor::stopThreads(size_t n) { - for (size_t i = 0; i < n; i++) { - const auto ioThread = std::static_pointer_cast( - threadList_.get()[i]); - ioThread->shouldRun = false; - ioThread->eventBase->terminateLoopSoon(); - } -} - -std::vector IOThreadPoolExecutor::getEventBases() { - std::vector bases; - RWSpinLock::ReadHolder{&threadListLock_}; - for (const auto& thread : threadList_.get()) { - auto ioThread = std::static_pointer_cast(thread); - bases.push_back(ioThread->eventBase); - } - return bases; -} - -// threadListLock_ is readlocked -uint64_t IOThreadPoolExecutor::getPendingTaskCount() { - uint64_t count = 0; - for (const auto& thread : threadList_.get()) { - auto ioThread = std::static_pointer_cast(thread); - size_t pendingTasks = ioThread->pendingTasks; - if (pendingTasks > 0 && !ioThread->idle) { - pendingTasks--; - } - count += pendingTasks; - } - return count; -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h b/folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h deleted file mode 100644 index 9196ef6d..00000000 --- a/folly/experimental/wangle/concurrent/IOThreadPoolExecutor.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace folly { namespace wangle { - -// N.B. For this thread pool, stop() behaves like join() because outstanding -// tasks belong to the event base and will be executed upon its destruction. -class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor { - public: - explicit IOThreadPoolExecutor( - size_t numThreads, - std::shared_ptr threadFactory = - std::make_shared("IOThreadPool")); - - ~IOThreadPoolExecutor(); - - void add(Func func) override; - void add( - Func func, - std::chrono::milliseconds expiration, - Func expireCallback = nullptr) override; - - EventBase* getEventBase() override; - - std::vector getEventBases(); - - private: - struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread { - IOThread(IOThreadPoolExecutor* pool) - : Thread(pool), - shouldRun(true), - pendingTasks(0) {}; - std::atomic shouldRun; - std::atomic pendingTasks; - EventBase* eventBase; - }; - - ThreadPtr makeThread() override; - std::shared_ptr pickThread(); - void threadRun(ThreadPtr thread) override; - void stopThreads(size_t n) override; - uint64_t getPendingTaskCount() override; - - size_t nextThread_; - ThreadLocal> thisThread_; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h b/folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h deleted file mode 100644 index ff499991..00000000 --- a/folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include - -namespace folly { namespace wangle { - -template -class LifoSemMPMCQueue : public BlockingQueue { - public: - explicit LifoSemMPMCQueue(size_t max_capacity) : queue_(max_capacity) {} - - void add(T item) override { - if (!queue_.write(std::move(item))) { - throw std::runtime_error("LifoSemMPMCQueue full, can't add item"); - } - sem_.post(); - } - - T take() override { - T item; - while (!queue_.read(item)) { - sem_.wait(); - } - return item; - } - - size_t capacity() { - return queue_.capacity(); - } - - size_t size() override { - return queue_.size(); - } - - private: - LifoSem sem_; - MPMCQueue queue_; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/NamedThreadFactory.h b/folly/experimental/wangle/concurrent/NamedThreadFactory.h deleted file mode 100644 index 2bb5579b..00000000 --- a/folly/experimental/wangle/concurrent/NamedThreadFactory.h +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include - -namespace folly { namespace wangle { - -class NamedThreadFactory : public ThreadFactory { - public: - explicit NamedThreadFactory(folly::StringPiece prefix) - : prefix_(prefix.str()), suffix_(0) {} - - std::thread newThread(Func&& func) override { - auto thread = std::thread(std::move(func)); - folly::setThreadName( - thread.native_handle(), - folly::to(prefix_, suffix_++)); - return thread; - } - - void setNamePrefix(folly::StringPiece prefix) { - prefix_ = prefix.str(); - } - - std::string getNamePrefix() { - return prefix_; - } - - private: - std::string prefix_; - std::atomic suffix_; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/PriorityLifoSemMPMCQueue.h b/folly/experimental/wangle/concurrent/PriorityLifoSemMPMCQueue.h deleted file mode 100644 index 65500f58..00000000 --- a/folly/experimental/wangle/concurrent/PriorityLifoSemMPMCQueue.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include - -namespace folly { namespace wangle { - -template -class PriorityLifoSemMPMCQueue : public BlockingQueue { - public: - explicit PriorityLifoSemMPMCQueue(uint32_t numPriorities, size_t capacity) { - CHECK(numPriorities > 0); - queues_.reserve(numPriorities); - for (int i = 0; i < numPriorities; i++) { - queues_.push_back(MPMCQueue(capacity)); - } - } - - uint32_t getNumPriorities() override { - return queues_.size(); - } - - // Add at lowest priority by default - void add(T item) override { - addWithPriority(std::move(item), 0); - } - - void addWithPriority(T item, uint32_t priority) override { - CHECK(priority < queues_.size()); - if (!queues_[priority].write(std::move(item))) { - throw std::runtime_error("LifoSemMPMCQueue full, can't add item"); - } - sem_.post(); - } - - T take() override { - T item; - while (true) { - for (auto it = queues_.rbegin(); it != queues_.rend(); it++) { - if (it->read(item)) { - return item; - } - } - sem_.wait(); - } - } - - size_t size() override { - size_t size = 0; - for (auto& q : queues_) { - size += q.size(); - } - return size; - } - - private: - LifoSem sem_; - std::vector> queues_; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/ThreadFactory.h b/folly/experimental/wangle/concurrent/ThreadFactory.h deleted file mode 100644 index 7654fbc9..00000000 --- a/folly/experimental/wangle/concurrent/ThreadFactory.h +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include - -#include - -namespace folly { namespace wangle { - -class ThreadFactory { - public: - virtual ~ThreadFactory() {} - virtual std::thread newThread(Func&& func) = 0; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/ThreadPoolExecutor.cpp b/folly/experimental/wangle/concurrent/ThreadPoolExecutor.cpp deleted file mode 100644 index 74890d7b..00000000 --- a/folly/experimental/wangle/concurrent/ThreadPoolExecutor.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace folly { namespace wangle { - -ThreadPoolExecutor::ThreadPoolExecutor( - size_t numThreads, - std::shared_ptr threadFactory) - : threadFactory_(std::move(threadFactory)), - taskStatsSubject_(std::make_shared>()) {} - -ThreadPoolExecutor::~ThreadPoolExecutor() { - CHECK(threadList_.get().size() == 0); -} - -ThreadPoolExecutor::Task::Task( - Func&& func, - std::chrono::milliseconds expiration, - Func&& expireCallback) - : func_(std::move(func)), - expiration_(expiration), - expireCallback_(std::move(expireCallback)) { - // Assume that the task in enqueued on creation - enqueueTime_ = std::chrono::steady_clock::now(); -} - -void ThreadPoolExecutor::runTask( - const ThreadPtr& thread, - Task&& task) { - thread->idle = false; - auto startTime = std::chrono::steady_clock::now(); - task.stats_.waitTime = startTime - task.enqueueTime_; - if (task.expiration_ > std::chrono::milliseconds(0) && - task.stats_.waitTime >= task.expiration_) { - task.stats_.expired = true; - if (task.expireCallback_ != nullptr) { - task.expireCallback_(); - } - } else { - try { - task.func_(); - } catch (const std::exception& e) { - LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " << - typeid(e).name() << " exception: " << e.what(); - } catch (...) { - LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception " - "object"; - } - task.stats_.runTime = std::chrono::steady_clock::now() - startTime; - } - thread->idle = true; - thread->taskStatsSubject->onNext(std::move(task.stats_)); -} - -size_t ThreadPoolExecutor::numThreads() { - RWSpinLock::ReadHolder{&threadListLock_}; - return threadList_.get().size(); -} - -void ThreadPoolExecutor::setNumThreads(size_t n) { - RWSpinLock::WriteHolder{&threadListLock_}; - const auto current = threadList_.get().size(); - if (n > current ) { - addThreads(n - current); - } else if (n < current) { - removeThreads(current - n, true); - } - CHECK(threadList_.get().size() == n); -} - -// threadListLock_ is writelocked -void ThreadPoolExecutor::addThreads(size_t n) { - std::vector newThreads; - for (size_t i = 0; i < n; i++) { - newThreads.push_back(makeThread()); - } - for (auto& thread : newThreads) { - // TODO need a notion of failing to create the thread - // and then handling for that case - thread->handle = threadFactory_->newThread( - std::bind(&ThreadPoolExecutor::threadRun, this, thread)); - threadList_.add(thread); - } - for (auto& thread : newThreads) { - thread->startupBaton.wait(); - } -} - -// threadListLock_ is writelocked -void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) { - CHECK(n <= threadList_.get().size()); - CHECK(stoppedThreads_.size() == 0); - isJoin_ = isJoin; - stopThreads(n); - for (size_t i = 0; i < n; i++) { - auto thread = stoppedThreads_.take(); - thread->handle.join(); - threadList_.remove(thread); - } - CHECK(stoppedThreads_.size() == 0); -} - -void ThreadPoolExecutor::stop() { - RWSpinLock::WriteHolder{&threadListLock_}; - removeThreads(threadList_.get().size(), false); - CHECK(threadList_.get().size() == 0); -} - -void ThreadPoolExecutor::join() { - RWSpinLock::WriteHolder{&threadListLock_}; - removeThreads(threadList_.get().size(), true); - CHECK(threadList_.get().size() == 0); -} - -ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() { - RWSpinLock::ReadHolder{&threadListLock_}; - ThreadPoolExecutor::PoolStats stats; - stats.threadCount = threadList_.get().size(); - for (auto thread : threadList_.get()) { - if (thread->idle) { - stats.idleThreadCount++; - } else { - stats.activeThreadCount++; - } - } - stats.pendingTaskCount = getPendingTaskCount(); - stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount; - return stats; -} - -std::atomic ThreadPoolExecutor::Thread::nextId(0); - -void ThreadPoolExecutor::StoppedThreadQueue::add( - ThreadPoolExecutor::ThreadPtr item) { - std::lock_guard guard(mutex_); - queue_.push(std::move(item)); - sem_.post(); -} - -ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() { - while(1) { - { - std::lock_guard guard(mutex_); - if (queue_.size() > 0) { - auto item = std::move(queue_.front()); - queue_.pop(); - return item; - } - } - sem_.wait(); - } -} - -size_t ThreadPoolExecutor::StoppedThreadQueue::size() { - std::lock_guard guard(mutex_); - return queue_.size(); -} - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h b/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h deleted file mode 100644 index 88aa1bc7..00000000 --- a/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include - -namespace folly { namespace wangle { - -class ThreadPoolExecutor : public virtual Executor { - public: - explicit ThreadPoolExecutor( - size_t numThreads, - std::shared_ptr threadFactory); - - ~ThreadPoolExecutor(); - - virtual void add(Func func) override = 0; - virtual void add( - Func func, - std::chrono::milliseconds expiration, - Func expireCallback) = 0; - - void setThreadFactory(std::shared_ptr threadFactory) { - CHECK(numThreads() == 0); - threadFactory_ = std::move(threadFactory); - } - - std::shared_ptr getThreadFactory(void) { - return threadFactory_; - } - - size_t numThreads(); - void setNumThreads(size_t numThreads); - /* - * stop() is best effort - there is no guarantee that unexecuted tasks won't - * be executed before it returns. Specifically, IOThreadPoolExecutor's stop() - * behaves like join(). - */ - void stop(); - void join(); - - struct PoolStats { - PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0), - pendingTaskCount(0), totalTaskCount(0) {} - size_t threadCount, idleThreadCount, activeThreadCount; - uint64_t pendingTaskCount, totalTaskCount; - }; - - PoolStats getPoolStats(); - - struct TaskStats { - TaskStats() : expired(false), waitTime(0), runTime(0) {} - bool expired; - std::chrono::nanoseconds waitTime; - std::chrono::nanoseconds runTime; - }; - - Subscription subscribeToTaskStats( - const ObserverPtr& observer) { - return taskStatsSubject_->subscribe(observer); - } - - protected: - // Prerequisite: threadListLock_ writelocked - void addThreads(size_t n); - // Prerequisite: threadListLock_ writelocked - void removeThreads(size_t n, bool isJoin); - - struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread { - explicit Thread(ThreadPoolExecutor* pool) - : id(nextId++), - handle(), - idle(true), - taskStatsSubject(pool->taskStatsSubject_) {} - - virtual ~Thread() {} - - static std::atomic nextId; - uint64_t id; - std::thread handle; - bool idle; - Baton<> startupBaton; - std::shared_ptr> taskStatsSubject; - }; - - typedef std::shared_ptr ThreadPtr; - - struct Task { - explicit Task( - Func&& func, - std::chrono::milliseconds expiration, - Func&& expireCallback); - Func func_; - TaskStats stats_; - std::chrono::steady_clock::time_point enqueueTime_; - std::chrono::milliseconds expiration_; - Func expireCallback_; - }; - - static void runTask(const ThreadPtr& thread, Task&& task); - - // The function that will be bound to pool threads. It must call - // thread->startupBaton.post() when it's ready to consume work. - virtual void threadRun(ThreadPtr thread) = 0; - - // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue - // Prerequisite: threadListLock_ writelocked - virtual void stopThreads(size_t n) = 0; - - // Create a suitable Thread struct - virtual ThreadPtr makeThread() { - return std::make_shared(this); - } - - // Prerequisite: threadListLock_ readlocked - virtual uint64_t getPendingTaskCount() = 0; - - class ThreadList { - public: - void add(const ThreadPtr& state) { - auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare); - vec_.insert(it, state); - } - - void remove(const ThreadPtr& state) { - auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare); - CHECK(itPair.first != vec_.end()); - CHECK(std::next(itPair.first) == itPair.second); - vec_.erase(itPair.first); - } - - const std::vector& get() const { - return vec_; - } - - private: - static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) { - return ts1->id < ts2->id; - } - - std::vector vec_; - }; - - class StoppedThreadQueue : public BlockingQueue { - public: - void add(ThreadPtr item) override; - ThreadPtr take() override; - size_t size() override; - - private: - LifoSem sem_; - std::mutex mutex_; - std::queue queue_; - }; - - std::shared_ptr threadFactory_; - ThreadList threadList_; - RWSpinLock threadListLock_; - StoppedThreadQueue stoppedThreads_; - std::atomic isJoin_; // whether the current downsizing is a join - - std::shared_ptr> taskStatsSubject_; -}; - -}} // folly::wangle diff --git a/folly/experimental/wangle/concurrent/test/CodelTest.cpp b/folly/experimental/wangle/concurrent/test/CodelTest.cpp deleted file mode 100644 index f13dc02b..00000000 --- a/folly/experimental/wangle/concurrent/test/CodelTest.cpp +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -TEST(CodelTest, Basic) { - using std::chrono::milliseconds; - folly::wangle::Codel c; - std::this_thread::sleep_for(milliseconds(110)); - // This interval is overloaded - EXPECT_FALSE(c.overloaded(milliseconds(100))); - std::this_thread::sleep_for(milliseconds(90)); - // At least two requests must happen in an interval before they will fail - EXPECT_FALSE(c.overloaded(milliseconds(50))); - EXPECT_TRUE(c.overloaded(milliseconds(50))); - std::this_thread::sleep_for(milliseconds(110)); - // Previous interval is overloaded, but 2ms isn't enough to fail - EXPECT_FALSE(c.overloaded(milliseconds(2))); - std::this_thread::sleep_for(milliseconds(90)); - // 20 ms > target interval * 2 - EXPECT_TRUE(c.overloaded(milliseconds(20))); -} diff --git a/folly/experimental/wangle/concurrent/test/GlobalExecutorTest.cpp b/folly/experimental/wangle/concurrent/test/GlobalExecutorTest.cpp deleted file mode 100644 index f0f678de..00000000 --- a/folly/experimental/wangle/concurrent/test/GlobalExecutorTest.cpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -using namespace folly::wangle; - -TEST(GlobalExecutorTest, GlobalIOExecutor) { - class DummyExecutor : public IOExecutor { - public: - void add(folly::Func f) override { - count++; - } - folly::EventBase* getEventBase() override { - return nullptr; - } - int count{0}; - }; - - auto f = [](){}; - - // Don't explode, we should create the default global IOExecutor lazily here. - getIOExecutor()->add(f); - - { - DummyExecutor dummy; - setIOExecutor(&dummy); - getIOExecutor()->add(f); - // Make sure we were properly installed. - EXPECT_EQ(1, dummy.count); - } - - // Don't explode, we should restore the default global IOExecutor when dummy - // is destructed. - getIOExecutor()->add(f); -} diff --git a/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp b/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp deleted file mode 100644 index 3e431073..00000000 --- a/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp +++ /dev/null @@ -1,320 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace folly::wangle; -using namespace std::chrono; - -static folly::Func burnMs(uint64_t ms) { - return [ms]() { std::this_thread::sleep_for(milliseconds(ms)); }; -} - -template -static void basic() { - // Create and destroy - TPE tpe(10); -} - -TEST(ThreadPoolExecutorTest, CPUBasic) { - basic(); -} - -TEST(IOThreadPoolExecutorTest, IOBasic) { - basic(); -} - -template -static void resize() { - TPE tpe(100); - EXPECT_EQ(100, tpe.numThreads()); - tpe.setNumThreads(50); - EXPECT_EQ(50, tpe.numThreads()); - tpe.setNumThreads(150); - EXPECT_EQ(150, tpe.numThreads()); -} - -TEST(ThreadPoolExecutorTest, CPUResize) { - resize(); -} - -TEST(ThreadPoolExecutorTest, IOResize) { - resize(); -} - -template -static void stop() { - TPE tpe(1); - std::atomic completed(0); - auto f = [&](){ - burnMs(10)(); - completed++; - }; - for (int i = 0; i < 1000; i++) { - tpe.add(f); - } - tpe.stop(); - EXPECT_GT(1000, completed); -} - -// IOThreadPoolExecutor's stop() behaves like join(). Outstanding tasks belong -// to the event base, will be executed upon its destruction, and cannot be -// taken back. -template <> -void stop() { - IOThreadPoolExecutor tpe(1); - std::atomic completed(0); - auto f = [&](){ - burnMs(10)(); - completed++; - }; - for (int i = 0; i < 10; i++) { - tpe.add(f); - } - tpe.stop(); - EXPECT_EQ(10, completed); -} - -TEST(ThreadPoolExecutorTest, CPUStop) { - stop(); -} - -TEST(ThreadPoolExecutorTest, IOStop) { - stop(); -} - -template -static void join() { - TPE tpe(10); - std::atomic completed(0); - auto f = [&](){ - burnMs(1)(); - completed++; - }; - for (int i = 0; i < 1000; i++) { - tpe.add(f); - } - tpe.join(); - EXPECT_EQ(1000, completed); -} - -TEST(ThreadPoolExecutorTest, CPUJoin) { - join(); -} - -TEST(ThreadPoolExecutorTest, IOJoin) { - join(); -} - -template -static void resizeUnderLoad() { - TPE tpe(10); - std::atomic completed(0); - auto f = [&](){ - burnMs(1)(); - completed++; - }; - for (int i = 0; i < 1000; i++) { - tpe.add(f); - } - tpe.setNumThreads(5); - tpe.setNumThreads(15); - tpe.join(); - EXPECT_EQ(1000, completed); -} - -TEST(ThreadPoolExecutorTest, CPUResizeUnderLoad) { - resizeUnderLoad(); -} - -TEST(ThreadPoolExecutorTest, IOResizeUnderLoad) { - resizeUnderLoad(); -} - -template -static void poolStats() { - folly::Baton<> startBaton, endBaton; - TPE tpe(1); - auto stats = tpe.getPoolStats(); - EXPECT_EQ(1, stats.threadCount); - EXPECT_EQ(1, stats.idleThreadCount); - EXPECT_EQ(0, stats.activeThreadCount); - EXPECT_EQ(0, stats.pendingTaskCount); - EXPECT_EQ(0, stats.totalTaskCount); - tpe.add([&](){ startBaton.post(); endBaton.wait(); }); - tpe.add([&](){}); - startBaton.wait(); - stats = tpe.getPoolStats(); - EXPECT_EQ(1, stats.threadCount); - EXPECT_EQ(0, stats.idleThreadCount); - EXPECT_EQ(1, stats.activeThreadCount); - EXPECT_EQ(1, stats.pendingTaskCount); - EXPECT_EQ(2, stats.totalTaskCount); - endBaton.post(); -} - -TEST(ThreadPoolExecutorTest, CPUPoolStats) { - poolStats(); -} - -TEST(ThreadPoolExecutorTest, IOPoolStats) { - poolStats(); -} - -template -static void taskStats() { - TPE tpe(1); - std::atomic c(0); - auto s = tpe.subscribeToTaskStats( - Observer::create( - [&](ThreadPoolExecutor::TaskStats stats) { - int i = c++; - EXPECT_LT(milliseconds(0), stats.runTime); - if (i == 1) { - EXPECT_LT(milliseconds(0), stats.waitTime); - } - })); - tpe.add(burnMs(10)); - tpe.add(burnMs(10)); - tpe.join(); - EXPECT_EQ(2, c); -} - -TEST(ThreadPoolExecutorTest, CPUTaskStats) { - taskStats(); -} - -TEST(ThreadPoolExecutorTest, IOTaskStats) { - taskStats(); -} - -template -static void expiration() { - TPE tpe(1); - std::atomic statCbCount(0); - auto s = tpe.subscribeToTaskStats( - Observer::create( - [&](ThreadPoolExecutor::TaskStats stats) { - int i = statCbCount++; - if (i == 0) { - EXPECT_FALSE(stats.expired); - } else if (i == 1) { - EXPECT_TRUE(stats.expired); - } else { - FAIL(); - } - })); - std::atomic expireCbCount(0); - auto expireCb = [&] () { expireCbCount++; }; - tpe.add(burnMs(10), seconds(60), expireCb); - tpe.add(burnMs(10), milliseconds(10), expireCb); - tpe.join(); - EXPECT_EQ(2, statCbCount); - EXPECT_EQ(1, expireCbCount); -} - -TEST(ThreadPoolExecutorTest, CPUExpiration) { - expiration(); -} - -TEST(ThreadPoolExecutorTest, IOExpiration) { - expiration(); -} - -template -static void futureExecutor() { - FutureExecutor fe(2); - std::atomic c{0}; - fe.addFuture([] () { return makeFuture(42); }).then( - [&] (Try&& t) { - c++; - EXPECT_EQ(42, t.value()); - }); - fe.addFuture([] () { return 100; }).then( - [&] (Try&& t) { - c++; - EXPECT_EQ(100, t.value()); - }); - fe.addFuture([] () { return makeFuture(); }).then( - [&] (Try&& t) { - c++; - EXPECT_NO_THROW(t.value()); - }); - fe.addFuture([] () { return; }).then( - [&] (Try&& t) { - c++; - EXPECT_NO_THROW(t.value()); - }); - fe.addFuture([] () { throw std::runtime_error("oops"); }).then( - [&] (Try&& t) { - c++; - EXPECT_THROW(t.value(), std::runtime_error); - }); - // Test doing actual async work - folly::Baton<> baton; - fe.addFuture([&] () { - auto p = std::make_shared>(); - std::thread t([p](){ - burnMs(10)(); - p->setValue(42); - }); - t.detach(); - return p->getFuture(); - }).then([&] (Try&& t) { - EXPECT_EQ(42, t.value()); - c++; - baton.post(); - }); - baton.wait(); - fe.join(); - EXPECT_EQ(6, c); -} - -TEST(ThreadPoolExecutorTest, CPUFuturePool) { - futureExecutor(); -} - -TEST(ThreadPoolExecutorTest, IOFuturePool) { - futureExecutor(); -} - -TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) { - bool tookLopri = false; - auto completed = 0; - auto hipri = [&] { - EXPECT_FALSE(tookLopri); - completed++; - }; - auto lopri = [&] { - tookLopri = true; - completed++; - }; - CPUThreadPoolExecutor pool(0, 2); - for (int i = 0; i < 50; i++) { - pool.add(lopri, 0); - } - for (int i = 0; i < 50; i++) { - pool.add(hipri, 1); - } - pool.setNumThreads(1); - pool.join(); - EXPECT_EQ(100, completed); -} diff --git a/folly/experimental/wangle/rx/Dummy.cpp b/folly/experimental/wangle/rx/Dummy.cpp deleted file mode 100644 index 02a58d4f..00000000 --- a/folly/experimental/wangle/rx/Dummy.cpp +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// fbbuild is too dumb to know that .h files in the directory affect -// our project, unless we have a .cpp file in the target, in the same -// directory. diff --git a/folly/experimental/wangle/rx/Observable.h b/folly/experimental/wangle/rx/Observable.h deleted file mode 100644 index e9a6196e..00000000 --- a/folly/experimental/wangle/rx/Observable.h +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace folly { namespace wangle { - -template -class Observable { - public: - Observable() : nextSubscriptionId_{1} {} - - // TODO perhaps we want to provide this #5283229 - Observable(Observable&& other) = delete; - - virtual ~Observable() { - if (unsubscriber_) { - unsubscriber_->disable(); - } - } - - // The next three methods subscribe the given Observer to this Observable. - // - // If these are called within an Observer callback, the new observer will not - // get the current update but will get subsequent updates. - // - // subscribe() returns a Subscription object. The observer will continue to - // get updates until the Subscription is destroyed. - // - // observe(ObserverPtr) creates an indefinite subscription - // - // observe(Observer*) also creates an indefinite subscription, but the - // caller is responsible for ensuring that the given Observer outlives this - // Observable. This might be useful in high performance environments where - // allocations must be kept to a minimum. Template parameter InlineObservers - // specifies how many observers can been subscribed inline without any - // allocations (it's just the size of a folly::small_vector). - virtual Subscription subscribe(ObserverPtr observer) { - return subscribeImpl(observer, false); - } - - virtual void observe(ObserverPtr observer) { - subscribeImpl(observer, true); - } - - virtual void observe(Observer* observer) { - if (inCallback_ && *inCallback_) { - if (!newObservers_) { - newObservers_.reset(new ObserverList()); - } - newObservers_->push_back(observer); - } else { - RWSpinLock::WriteHolder{&observersLock_}; - observers_.push_back(observer); - } - } - - // TODO unobserve(ObserverPtr), unobserve(Observer*) - - /// Returns a new Observable that will call back on the given Scheduler. - /// The returned Observable must outlive the parent Observable. - - // This and subscribeOn should maybe just be a first-class feature of an - // Observable, rather than making new ones whose lifetimes are tied to their - // parents. In that case it'd return a reference to this object for - // chaining. - ObservablePtr observeOn(SchedulerPtr scheduler) { - // you're right Hannes, if we have Observable::create we don't need this - // helper class. - struct ViaSubject : public Observable - { - ViaSubject(SchedulerPtr sched, - Observable* obs) - : scheduler_(sched), observable_(obs) - {} - - Subscription subscribe(ObserverPtr o) override { - return observable_->subscribe( - Observer::create( - [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); }, - [=](Error e) { scheduler_->add([o, e] { o->onError(e); }); }, - [=]() { scheduler_->add([o] { o->onCompleted(); }); })); - } - - protected: - SchedulerPtr scheduler_; - Observable* observable_; - }; - - return std::make_shared(scheduler, this); - } - - /// Returns a new Observable that will subscribe to this parent Observable - /// via the given Scheduler. This can be subtle and confusing at first, see - /// http://www.introtorx.com/Content/v1.0.10621.0/15_SchedulingAndThreading.html#SubscribeOnObserveOn - std::unique_ptr subscribeOn(SchedulerPtr scheduler) { - struct Subject_ : public Subject { - public: - Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) { - } - - Subscription subscribe(ObserverPtr o) { - scheduler_->add([=] { - observable_->subscribe(o); - }); - return Subscription(nullptr, 0); // TODO - } - - protected: - SchedulerPtr scheduler_; - Observable* observable_; - }; - - return folly::make_unique(scheduler, this); - } - - protected: - // Safely execute an operation on each observer. F must take a single - // Observer* as its argument. - template - void forEachObserver(F f) { - if (UNLIKELY(!inCallback_)) { - inCallback_.reset(new bool{false}); - } - CHECK(!(*inCallback_)); - *inCallback_ = true; - - { - RWSpinLock::ReadHolder rh(observersLock_); - for (auto o : observers_) { - f(o); - } - - for (auto& kv : subscribers_) { - f(kv.second.get()); - } - } - - if (UNLIKELY((newObservers_ && !newObservers_->empty()) || - (newSubscribers_ && !newSubscribers_->empty()) || - (oldSubscribers_ && !oldSubscribers_->empty()))) { - { - RWSpinLock::WriteHolder wh(observersLock_); - if (newObservers_) { - for (auto observer : *(newObservers_)) { - observers_.push_back(observer); - } - newObservers_->clear(); - } - if (newSubscribers_) { - for (auto& kv : *(newSubscribers_)) { - subscribers_.insert(std::move(kv)); - } - newSubscribers_->clear(); - } - if (oldSubscribers_) { - for (auto id : *(oldSubscribers_)) { - subscribers_.erase(id); - } - oldSubscribers_->clear(); - } - } - } - *inCallback_ = false; - } - - private: - Subscription subscribeImpl(ObserverPtr observer, bool indefinite) { - auto subscription = makeSubscription(indefinite); - typename SubscriberMap::value_type kv{subscription.id_, std::move(observer)}; - if (inCallback_ && *inCallback_) { - if (!newSubscribers_) { - newSubscribers_.reset(new SubscriberMap()); - } - newSubscribers_->insert(std::move(kv)); - } else { - RWSpinLock::WriteHolder{&observersLock_}; - subscribers_.insert(std::move(kv)); - } - return subscription; - } - - class Unsubscriber { - public: - explicit Unsubscriber(Observable* observable) : observable_(observable) { - CHECK(observable_); - } - - void unsubscribe(uint64_t id) { - CHECK(id > 0); - RWSpinLock::ReadHolder guard(lock_); - if (observable_) { - observable_->unsubscribe(id); - } - } - - void disable() { - RWSpinLock::WriteHolder guard(lock_); - observable_ = nullptr; - } - - private: - RWSpinLock lock_; - Observable* observable_; - }; - - std::shared_ptr unsubscriber_{nullptr}; - MicroSpinLock unsubscriberLock_{0}; - - friend class Subscription; - - void unsubscribe(uint64_t id) { - if (inCallback_ && *inCallback_) { - if (!oldSubscribers_) { - oldSubscribers_.reset(new std::vector()); - } - if (newSubscribers_) { - auto it = newSubscribers_->find(id); - if (it != newSubscribers_->end()) { - newSubscribers_->erase(it); - return; - } - } - oldSubscribers_->push_back(id); - } else { - RWSpinLock::WriteHolder{&observersLock_}; - subscribers_.erase(id); - } - } - - Subscription makeSubscription(bool indefinite) { - if (indefinite) { - return Subscription(nullptr, nextSubscriptionId_++); - } else { - if (!unsubscriber_) { - std::lock_guard guard(unsubscriberLock_); - if (!unsubscriber_) { - unsubscriber_ = std::make_shared(this); - } - } - return Subscription(unsubscriber_, nextSubscriptionId_++); - } - } - - std::atomic nextSubscriptionId_; - RWSpinLock observersLock_; - folly::ThreadLocalPtr inCallback_; - - typedef folly::small_vector*, InlineObservers> ObserverList; - ObserverList observers_; - folly::ThreadLocalPtr newObservers_; - - typedef std::map> SubscriberMap; - SubscriberMap subscribers_; - folly::ThreadLocalPtr newSubscribers_; - folly::ThreadLocalPtr> oldSubscribers_; -}; - -}} diff --git a/folly/experimental/wangle/rx/Observer.h b/folly/experimental/wangle/rx/Observer.h deleted file mode 100644 index cfe49dd9..00000000 --- a/folly/experimental/wangle/rx/Observer.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace folly { namespace wangle { - -template class FunctionObserver; - -/// Observer interface. You can subclass it, or you can just use create() -/// to use std::functions. -template -struct Observer { - // These are what it means to be an Observer. - virtual void onNext(const T&) = 0; - virtual void onError(Error) = 0; - virtual void onCompleted() = 0; - - virtual ~Observer() = default; - - /// Create an Observer with std::function callbacks. Handy to make ad-hoc - /// Observers with lambdas. - /// - /// Templated for maximum perfect forwarding flexibility, but ultimately - /// whatever you pass in has to implicitly become a std::function for the - /// same signature as onNext(), onError(), and onCompleted() respectively. - /// (see the FunctionObserver typedefs) - template - static std::unique_ptr create( - N&& onNextFn, E&& onErrorFn, C&& onCompletedFn) - { - return folly::make_unique>( - std::forward(onNextFn), - std::forward(onErrorFn), - std::forward(onCompletedFn)); - } - - /// Create an Observer with only onNext and onError callbacks. - /// onCompleted will just be a no-op. - template - static std::unique_ptr create(N&& onNextFn, E&& onErrorFn) { - return folly::make_unique>( - std::forward(onNextFn), - std::forward(onErrorFn), - nullptr); - } - - /// Create an Observer with only an onNext callback. - /// onError and onCompleted will just be no-ops. - template - static std::unique_ptr create(N&& onNextFn) { - return folly::make_unique>( - std::forward(onNextFn), - nullptr, - nullptr); - } -}; - -/// An observer that uses std::function callbacks. You don't really want to -/// make one of these directly - instead use the Observer::create() methods. -template -struct FunctionObserver : public Observer { - typedef std::function OnNext; - typedef std::function OnError; - typedef std::function OnCompleted; - - /// We don't need any fancy overloads of this constructor because that's - /// what Observer::create() is for. - template - FunctionObserver(N&& n, E&& e, C&& c) - : onNext_(std::forward(n)), - onError_(std::forward(e)), - onCompleted_(std::forward(c)) - {} - - void onNext(const T& val) override { - if (onNext_) onNext_(val); - } - - void onError(Error e) override { - if (onError_) onError_(e); - } - - void onCompleted() override { - if (onCompleted_) onCompleted_(); - } - - protected: - OnNext onNext_; - OnError onError_; - OnCompleted onCompleted_; -}; - -}} diff --git a/folly/experimental/wangle/rx/README b/folly/experimental/wangle/rx/README deleted file mode 100644 index ee170f35..00000000 --- a/folly/experimental/wangle/rx/README +++ /dev/null @@ -1,36 +0,0 @@ -Rx is a pattern for "functional reactive programming" that started at -Microsoft in C#, and has been reimplemented in various languages, notably -RxJava for JVM languages. - -It is basically the plural of Futures (a la Wangle). - - - singular | plural - +---------------------------------+----------------------------------- - sync | Foo getData() | std::vector getData() - async | wangle::Future getData() | wangle::Observable getData() - - -For more on Rx, I recommend these resources: - -Netflix blog post (RxJava): http://techblog.netflix.com/2013/02/rxjava-netflix-api.html -Introduction to Rx eBook (C#): http://www.introtorx.com/content/v1.0.10621.0/01_WhyRx.html -The RxJava wiki: https://github.com/Netflix/RxJava/wiki -Netflix QCon presentation: http://www.infoq.com/presentations/netflix-functional-rx -https://rx.codeplex.com/ - -There are open source C++ implementations, I haven't looked at them. They -might be the best way to go rather than writing it NIH-style. I mostly did it -as an exercise, to think through how closely we might want to integrate -something like this with Wangle, and to get a feel for how it works in C++. - -I haven't even tried to support move-only data in this version. I'm on the -fence about the usage of shared_ptr. Subject is underdeveloped. A whole rich -set of operations is obviously missing. I haven't decided how to handle -subscriptions (and therefore cancellation), but I'm pretty sure C#'s -"Disposable" is thoroughly un-C++ (opposite of RAII). So for now subscribe -returns nothing at all and you can't cancel anything ever. The whole thing is -probably riddled with lifetime corner case bugs that will come out like a -swarm of angry bees as soon as someone tries an infinite sequence, or tries to -partially observe a long sequence. I'm pretty sure subscribeOn has a bug that -I haven't tracked down yet. diff --git a/folly/experimental/wangle/rx/Subject.h b/folly/experimental/wangle/rx/Subject.h deleted file mode 100644 index 6ff04c0e..00000000 --- a/folly/experimental/wangle/rx/Subject.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace folly { namespace wangle { - -/// Subject interface. A Subject is both an Observable and an Observer. There -/// is a default implementation of the Observer methods that just forwards the -/// observed events to the Subject's observers. -template -struct Subject : public Observable, public Observer { - void onNext(const T& val) override { - this->forEachObserver([&](Observer* o){ - o->onNext(val); - }); - } - void onError(Error e) override { - this->forEachObserver([&](Observer* o){ - o->onError(e); - }); - } - void onCompleted() override { - this->forEachObserver([](Observer* o){ - o->onCompleted(); - }); - } -}; - -}} diff --git a/folly/experimental/wangle/rx/Subscription.h b/folly/experimental/wangle/rx/Subscription.h deleted file mode 100644 index 7c058e23..00000000 --- a/folly/experimental/wangle/rx/Subscription.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace folly { namespace wangle { - -template -class Subscription { - public: - Subscription() {} - - Subscription(const Subscription&) = delete; - - Subscription(Subscription&& other) noexcept { - *this = std::move(other); - } - - Subscription& operator=(Subscription&& other) noexcept { - unsubscribe(); - unsubscriber_ = std::move(other.unsubscriber_); - id_ = other.id_; - other.unsubscriber_ = nullptr; - other.id_ = 0; - return *this; - } - - ~Subscription() { - unsubscribe(); - } - - private: - typedef typename Observable::Unsubscriber Unsubscriber; - - Subscription(std::shared_ptr unsubscriber, uint64_t id) - : unsubscriber_(std::move(unsubscriber)), id_(id) { - CHECK(id_ > 0); - } - - void unsubscribe() { - if (unsubscriber_ && id_ > 0) { - unsubscriber_->unsubscribe(id_); - id_ = 0; - unsubscriber_ = nullptr; - } - } - - std::shared_ptr unsubscriber_; - uint64_t id_{0}; - - friend class Observable; -}; - -}} diff --git a/folly/experimental/wangle/rx/test/RxBenchmark.cpp b/folly/experimental/wangle/rx/test/RxBenchmark.cpp deleted file mode 100644 index c4123100..00000000 --- a/folly/experimental/wangle/rx/test/RxBenchmark.cpp +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -using namespace folly::wangle; -using folly::BenchmarkSuspender; - -static std::unique_ptr> makeObserver() { - return Observer::create([&] (int x) {}); -} - -void subscribeImpl(uint iters, int N, bool countUnsubscribe) { - for (uint iter = 0; iter < iters; iter++) { - BenchmarkSuspender bs; - Subject subject; - std::vector>> observers; - std::vector> subscriptions; - subscriptions.reserve(N); - for (int i = 0; i < N; i++) { - observers.push_back(makeObserver()); - } - bs.dismiss(); - for (int i = 0; i < N; i++) { - subscriptions.push_back(subject.subscribe(std::move(observers[i]))); - } - if (countUnsubscribe) { - subscriptions.clear(); - } - bs.rehire(); - } -} - -void subscribeAndUnsubscribe(uint iters, int N) { - subscribeImpl(iters, N, true); -} - -void subscribe(uint iters, int N) { - subscribeImpl(iters, N, false); -} - -void observe(uint iters, int N) { - for (uint iter = 0; iter < iters; iter++) { - BenchmarkSuspender bs; - Subject subject; - std::vector>> observers; - for (int i = 0; i < N; i++) { - observers.push_back(makeObserver()); - } - bs.dismiss(); - for (int i = 0; i < N; i++) { - subject.observe(std::move(observers[i])); - } - bs.rehire(); - } -} - -void inlineObserve(uint iters, int N) { - for (uint iter = 0; iter < iters; iter++) { - BenchmarkSuspender bs; - Subject subject; - std::vector*> observers; - for (int i = 0; i < N; i++) { - observers.push_back(makeObserver().release()); - } - bs.dismiss(); - for (int i = 0; i < N; i++) { - subject.observe(observers[i]); - } - bs.rehire(); - for (int i = 0; i < N; i++) { - delete observers[i]; - } - } -} - -void notifySubscribers(uint iters, int N) { - for (uint iter = 0; iter < iters; iter++) { - BenchmarkSuspender bs; - Subject subject; - std::vector>> observers; - std::vector> subscriptions; - subscriptions.reserve(N); - for (int i = 0; i < N; i++) { - observers.push_back(makeObserver()); - } - for (int i = 0; i < N; i++) { - subscriptions.push_back(subject.subscribe(std::move(observers[i]))); - } - bs.dismiss(); - subject.onNext(42); - bs.rehire(); - } -} - -void notifyInlineObservers(uint iters, int N) { - for (uint iter = 0; iter < iters; iter++) { - BenchmarkSuspender bs; - Subject subject; - std::vector*> observers; - for (int i = 0; i < N; i++) { - observers.push_back(makeObserver().release()); - } - for (int i = 0; i < N; i++) { - subject.observe(observers[i]); - } - bs.dismiss(); - subject.onNext(42); - bs.rehire(); - } -} - -BENCHMARK_PARAM(subscribeAndUnsubscribe, 1); -BENCHMARK_RELATIVE_PARAM(subscribe, 1); -BENCHMARK_RELATIVE_PARAM(observe, 1); -BENCHMARK_RELATIVE_PARAM(inlineObserve, 1); - -BENCHMARK_DRAW_LINE(); - -BENCHMARK_PARAM(subscribeAndUnsubscribe, 1000); -BENCHMARK_RELATIVE_PARAM(subscribe, 1000); -BENCHMARK_RELATIVE_PARAM(observe, 1000); -BENCHMARK_RELATIVE_PARAM(inlineObserve, 1000); - -BENCHMARK_DRAW_LINE(); - -BENCHMARK_PARAM(notifySubscribers, 1); -BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1); - -BENCHMARK_DRAW_LINE(); - -BENCHMARK_PARAM(notifySubscribers, 1000); -BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1000); - -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - folly::runBenchmarks(); - return 0; -} diff --git a/folly/experimental/wangle/rx/test/RxTest.cpp b/folly/experimental/wangle/rx/test/RxTest.cpp deleted file mode 100644 index 8cf2605d..00000000 --- a/folly/experimental/wangle/rx/test/RxTest.cpp +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -using namespace folly::wangle; - -static std::unique_ptr> incrementer(int& counter) { - return Observer::create([&] (int x) { - counter++; - }); -} - -TEST(RxTest, Observe) { - Subject subject; - auto count = 0; - subject.observe(incrementer(count)); - subject.onNext(1); - EXPECT_EQ(1, count); -} - -TEST(RxTest, ObserveInline) { - Subject subject; - auto count = 0; - auto o = incrementer(count).release(); - subject.observe(o); - subject.onNext(1); - EXPECT_EQ(1, count); - delete o; -} - -TEST(RxTest, Subscription) { - Subject subject; - auto count = 0; - { - auto s = subject.subscribe(incrementer(count)); - subject.onNext(1); - } - // The subscription has gone out of scope so no one should get this. - subject.onNext(2); - EXPECT_EQ(1, count); -} - -TEST(RxTest, SubscriptionMove) { - Subject subject; - auto count = 0; - auto s = subject.subscribe(incrementer(count)); - auto s2 = subject.subscribe(incrementer(count)); - s2 = std::move(s); - subject.onNext(1); - Subscription s3(std::move(s2)); - subject.onNext(2); - EXPECT_EQ(2, count); -} - -TEST(RxTest, SubscriptionOutlivesSubject) { - Subscription s; - { - Subject subject; - s = subject.subscribe(Observer::create([](int){})); - } - // Don't explode when s is destroyed -} - -TEST(RxTest, SubscribeDuringCallback) { - // A subscriber who was subscribed in the course of a callback should get - // subsequent updates but not the current update. - Subject subject; - int outerCount = 0, innerCount = 0; - Subscription s1, s2; - s1 = subject.subscribe(Observer::create([&] (int x) { - outerCount++; - s2 = subject.subscribe(incrementer(innerCount)); - })); - subject.onNext(42); - subject.onNext(0xDEADBEEF); - EXPECT_EQ(2, outerCount); - EXPECT_EQ(1, innerCount); -} - -TEST(RxTest, ObserveDuringCallback) { - Subject subject; - int outerCount = 0, innerCount = 0; - subject.observe(Observer::create([&] (int x) { - outerCount++; - subject.observe(incrementer(innerCount)); - })); - subject.onNext(42); - subject.onNext(0xDEADBEEF); - EXPECT_EQ(2, outerCount); - EXPECT_EQ(1, innerCount); -} - -TEST(RxTest, ObserveInlineDuringCallback) { - Subject subject; - int outerCount = 0, innerCount = 0; - auto innerO = incrementer(innerCount).release(); - auto outerO = Observer::create([&] (int x) { - outerCount++; - subject.observe(innerO); - }).release(); - subject.observe(outerO); - subject.onNext(42); - subject.onNext(0xDEADBEEF); - EXPECT_EQ(2, outerCount); - EXPECT_EQ(1, innerCount); - delete innerO; - delete outerO; -} - -TEST(RxTest, UnsubscribeDuringCallback) { - // A subscriber who was unsubscribed in the course of a callback should get - // the current update but not subsequent ones - Subject subject; - int count1 = 0, count2 = 0; - auto s1 = subject.subscribe(incrementer(count1)); - auto s2 = subject.subscribe(Observer::create([&] (int x) { - count2++; - s1.~Subscription(); - })); - subject.onNext(1); - subject.onNext(2); - EXPECT_EQ(1, count1); - EXPECT_EQ(2, count2); -} - -TEST(RxTest, SubscribeUnsubscribeDuringCallback) { - // A subscriber who was subscribed and unsubscribed in the course of a - // callback should not get any updates - Subject subject; - int outerCount = 0, innerCount = 0; - auto s2 = subject.subscribe(Observer::create([&] (int x) { - outerCount++; - auto s2 = subject.subscribe(incrementer(innerCount)); - })); - subject.onNext(1); - subject.onNext(2); - EXPECT_EQ(2, outerCount); - EXPECT_EQ(0, innerCount); -} - -// Move only type -typedef std::unique_ptr MO; -static MO makeMO() { return folly::make_unique(1); } -template -static ObserverPtr makeMOObserver() { - return Observer::create([](const T& mo) { - EXPECT_EQ(1, *mo); - }); -} - -TEST(RxTest, MoveOnlyRvalue) { - Subject subject; - auto s1 = subject.subscribe(makeMOObserver()); - auto s2 = subject.subscribe(makeMOObserver()); - auto mo = makeMO(); - // Can't bind lvalues to rvalue references - // subject.onNext(mo); - subject.onNext(std::move(mo)); - subject.onNext(makeMO()); -} - -// Copy only type -struct CO { - CO() = default; - CO(const CO&) = default; - CO(CO&&) = delete; -}; - -template -static ObserverPtr makeCOObserver() { - return Observer::create([](const T& mo) {}); -} - -TEST(RxTest, CopyOnly) { - Subject subject; - auto s1 = subject.subscribe(makeCOObserver()); - CO co; - subject.onNext(co); -} diff --git a/folly/experimental/wangle/rx/types.h b/folly/experimental/wangle/rx/types.h deleted file mode 100644 index 27c2f3b7..00000000 --- a/folly/experimental/wangle/rx/types.h +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright 2014 Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace folly { namespace wangle { - typedef folly::exception_wrapper Error; - // The Executor is basically an rx Scheduler (by design). So just - // alias it. - typedef std::shared_ptr SchedulerPtr; - - template struct Observable; - template struct Observer; - template struct Subject; - - template using ObservablePtr = std::shared_ptr>; - template using ObserverPtr = std::shared_ptr>; - template using SubjectPtr = std::shared_ptr>; -}} diff --git a/folly/experimental/wangle/ssl/ClientHelloExtStats.h b/folly/experimental/wangle/ssl/ClientHelloExtStats.h deleted file mode 100644 index a95ee0c6..00000000 --- a/folly/experimental/wangle/ssl/ClientHelloExtStats.h +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -namespace folly { - -class ClientHelloExtStats { - public: - virtual ~ClientHelloExtStats() noexcept {} - - // client hello - virtual void recordAbsentHostname() noexcept = 0; - virtual void recordMatch() noexcept = 0; - virtual void recordNotMatch() noexcept = 0; -}; - -} diff --git a/folly/experimental/wangle/ssl/DHParam.h b/folly/experimental/wangle/ssl/DHParam.h deleted file mode 100644 index 561d5691..00000000 --- a/folly/experimental/wangle/ssl/DHParam.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include - -// The following was auto-generated by -// openssl dhparam -C 2048 -DH *get_dh2048() - { - static unsigned char dh2048_p[]={ - 0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95, - 0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79, - 0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02, - 0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC, - 0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8, - 0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6, - 0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8, - 0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF, - 0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E, - 0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09, - 0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D, - 0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B, - 0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A, - 0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8, - 0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E, - 0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4, - 0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9, - 0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E, - 0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6, - 0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD, - 0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB, - 0x8C,0xA1,0x76,0xDB, - }; - static unsigned char dh2048_g[]={ - 0x02, - }; - DH *dh; - - if ((dh=DH_new()) == nullptr) return(nullptr); - dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr); - dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr); - if ((dh->p == nullptr) || (dh->g == nullptr)) - { DH_free(dh); return(nullptr); } - return(dh); - } diff --git a/folly/experimental/wangle/ssl/PasswordInFile.cpp b/folly/experimental/wangle/ssl/PasswordInFile.cpp deleted file mode 100644 index c876c39d..00000000 --- a/folly/experimental/wangle/ssl/PasswordInFile.cpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include - -using namespace std; - -namespace folly { - -PasswordInFile::PasswordInFile(const string& file) - : fileName_(file) { - folly::readFile(file.c_str(), password_); - auto p = password_.find('\0'); - if (p != std::string::npos) { - password_.erase(p); - } -} - -PasswordInFile::~PasswordInFile() { - OPENSSL_cleanse((char *)password_.data(), password_.length()); -} - -} diff --git a/folly/experimental/wangle/ssl/PasswordInFile.h b/folly/experimental/wangle/ssl/PasswordInFile.h deleted file mode 100644 index b0a09227..00000000 --- a/folly/experimental/wangle/ssl/PasswordInFile.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include // PasswordCollector - -namespace folly { - -class PasswordInFile: public folly::PasswordCollector { - public: - explicit PasswordInFile(const std::string& file); - ~PasswordInFile(); - - void getPassword(std::string& password, int size) override { - password = password_; - } - - const char* getPasswordStr() const { - return password_.c_str(); - } - - std::string describe() const override { - return fileName_; - } - - protected: - std::string fileName_; - std::string password_; -}; - -} diff --git a/folly/experimental/wangle/ssl/SSLCacheOptions.h b/folly/experimental/wangle/ssl/SSLCacheOptions.h deleted file mode 100644 index 56175378..00000000 --- a/folly/experimental/wangle/ssl/SSLCacheOptions.h +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include - -namespace folly { - -struct SSLCacheOptions { - std::chrono::seconds sslCacheTimeout; - uint64_t maxSSLCacheSize; - uint64_t sslCacheFlushSize; -}; - -} diff --git a/folly/experimental/wangle/ssl/SSLCacheProvider.h b/folly/experimental/wangle/ssl/SSLCacheProvider.h deleted file mode 100644 index feecca46..00000000 --- a/folly/experimental/wangle/ssl/SSLCacheProvider.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include - -namespace folly { - -class SSLSessionCacheManager; - -/** - * Interface to be implemented by providers of external session caches - */ -class SSLCacheProvider { -public: - /** - * Context saved during an external cache request that is used to - * resume the waiting client. - */ - typedef struct { - std::string sessionId; - SSL_SESSION* session; - SSLSessionCacheManager* manager; - AsyncSSLSocket* sslSocket; - std::unique_ptr< - folly::DelayedDestruction::DestructorGuard> guard; - } CacheContext; - - virtual ~SSLCacheProvider() {} - - /** - * Store a session in the external cache. - * @param sessionId Identifier that can be used later to fetch the - * session with getAsync() - * @param value Serialized session to store - * @param expiration Relative expiration time: seconds from now - * @return true if the storing of the session is initiated successfully - * (though not necessarily completed; the completion may - * happen either before or after this method returns), or - * false if the storing cannot be initiated due to an error. - */ - virtual bool setAsync(const std::string& sessionId, - const std::string& value, - std::chrono::seconds expiration) = 0; - - /** - * Retrieve a session from the external cache. When done, call - * the cache manager's onGetSuccess() or onGetFailure() callback. - * @param sessionId Session ID to fetch - * @param context Data to pass back to the SSLSessionCacheManager - * in the completion callback - * @return true if the lookup of the session is initiated successfully - * (though not necessarily completed; the completion may - * happen either before or after this method returns), or - * false if the lookup cannot be initiated due to an error. - */ - virtual bool getAsync(const std::string& sessionId, - CacheContext* context) = 0; - -}; - -} diff --git a/folly/experimental/wangle/ssl/SSLContextConfig.h b/folly/experimental/wangle/ssl/SSLContextConfig.h deleted file mode 100644 index bd3f8044..00000000 --- a/folly/experimental/wangle/ssl/SSLContextConfig.h +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include -#include - -/** - * SSLContextConfig helps to describe the configs/options for - * a SSL_CTX. For example: - * - * 1. Filename of X509, private key and its password. - * 2. ciphers list - * 3. NPN list - * 4. Is session cache enabled? - * 5. Is it the default X509 in SNI operation? - * 6. .... and a few more - */ -namespace folly { - -struct SSLContextConfig { - SSLContextConfig() {} - ~SSLContextConfig() {} - - struct CertificateInfo { - std::string certPath; - std::string keyPath; - std::string passwordPath; - }; - - /** - * Helpers to set/add a certificate - */ - void setCertificate(const std::string& certPath, - const std::string& keyPath, - const std::string& passwordPath) { - certificates.clear(); - addCertificate(certPath, keyPath, passwordPath); - } - - void addCertificate(const std::string& certPath, - const std::string& keyPath, - const std::string& passwordPath) { - certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath}); - } - - /** - * Set the optional list of protocols to advertise via TLS - * Next Protocol Negotiation. An empty list means NPN is not enabled. - */ - void setNextProtocols(const std::list& inNextProtocols) { - nextProtocols.clear(); - nextProtocols.push_back({1, inNextProtocols}); - } - - typedef std::function SNINoMatchFn; - - std::vector certificates; - folly::SSLContext::SSLVersion sslVersion{ - folly::SSLContext::TLSv1}; - bool sessionCacheEnabled{true}; - bool sessionTicketEnabled{true}; - bool clientHelloParsingEnabled{false}; - std::string sslCiphers{ - "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:" - "ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:" - "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:" - "AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:" - "ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:" - "ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"}; - std::string eccCurveName; - // Ciphers to negotiate if TLS version >= 1.1 - std::string tls11Ciphers{""}; - // Weighted lists of NPN strings to advertise - std::list - nextProtocols; - bool isLocalPrivateKey{true}; - // Should this SSLContextConfig be the default for SNI purposes - bool isDefault{false}; - // Callback function to invoke when there are no matching certificates - // (will only be invoked once) - SNINoMatchFn sniNoMatchFn; - // File containing trusted CA's to validate client certificates - std::string clientCAFile; -}; - -} diff --git a/folly/experimental/wangle/ssl/SSLContextManager.cpp b/folly/experimental/wangle/ssl/SSLContextManager.cpp deleted file mode 100644 index eb9f1266..00000000 --- a/folly/experimental/wangle/ssl/SSLContextManager.cpp +++ /dev/null @@ -1,651 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#define OPENSSL_MISSING_FEATURE(name) \ -do { \ - throw std::runtime_error("missing " #name " support in openssl"); \ -} while(0) - - -using std::string; -using std::shared_ptr; - -/** - * SSLContextManager helps to create and manage all SSL_CTX, - * SSLSessionCacheManager and TLSTicketManager for a listening - * VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)). - * - * Other responsibilities: - * 1. It also handles the SSL_CTX selection after getting the tlsext_hostname - * in the client hello message. - * - * Usage: - * 1. Each listening VIP:PORT serving SSL should have one SSLContextManager. - * It maps to Acceptor in the wangle vocabulary. - * - * 2. Create a SSLContextConfig object (e.g. by parsing the JSON config). - * - * 3. Call SSLContextManager::addSSLContextConfig() which will - * then create and configure the SSL_CTX - * - * Note: Each Acceptor, with SSL support, should have one SSLContextManager to - * manage all SSL_CTX for the VIP:PORT. - */ - -namespace folly { - -namespace { - -X509* getX509(SSL_CTX* ctx) { - SSL* ssl = SSL_new(ctx); - SSL_set_connect_state(ssl); - X509* x509 = SSL_get_certificate(ssl); - CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509); - SSL_free(ssl); - return x509; -} - -void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) { -#if OPENSSL_VERSION_NUMBER >= 0x0090800fL -#ifndef OPENSSL_NO_ECDH - EC_KEY* ecdh = nullptr; - int nid; - - /* - * Elliptic-Curve Diffie-Hellman parameters are either "named curves" - * from RFC 4492 section 5.1.1, or explicitly described curves over - * binary fields. OpenSSL only supports the "named curves", which provide - * maximum interoperability. - */ - - nid = OBJ_sn2nid(curveName.c_str()); - if (nid == 0) { - LOG(FATAL) << "Unknown curve name:" << curveName.c_str(); - return; - } - ecdh = EC_KEY_new_by_curve_name(nid); - if (ecdh == nullptr) { - LOG(FATAL) << "Unable to create curve:" << curveName.c_str(); - return; - } - - SSL_CTX_set_tmp_ecdh(ctx, ecdh); - EC_KEY_free(ecdh); -#endif -#endif -} - -// Helper to create TLSTicketKeyManger and aware of the needed openssl -// version/feature. -std::unique_ptr createTicketManagerHelper( - std::shared_ptr ctx, - const TLSTicketKeySeeds* ticketSeeds, - const SSLContextConfig& ctxConfig, - SSLStats* stats) { - - std::unique_ptr ticketManager; -#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB - if (ticketSeeds && ctxConfig.sessionTicketEnabled) { - ticketManager = folly::make_unique(ctx.get(), stats); - ticketManager->setTLSTicketKeySeeds( - ticketSeeds->oldSeeds, - ticketSeeds->currentSeeds, - ticketSeeds->newSeeds); - } else { - ctx->setOptions(SSL_OP_NO_TICKET); - } -#else - if (ticketSeeds && ctxConfig.sessionTicketEnabled) { - OPENSSL_MISSING_FEATURE(TLSTicket); - } -#endif - return ticketManager; -} - -std::string flattenList(const std::list& list) { - std::string s; - bool first = true; - for (auto& item : list) { - if (first) { - first = false; - } else { - s.append(", "); - } - s.append(item); - } - return s; -} - -} - -SSLContextManager::~SSLContextManager() {} - -SSLContextManager::SSLContextManager( - EventBase* eventBase, - const std::string& vipName, - bool strict, - SSLStats* stats) : - stats_(stats), - eventBase_(eventBase), - strict_(strict) { -} - -void SSLContextManager::addSSLContextConfig( - const SSLContextConfig& ctxConfig, - const SSLCacheOptions& cacheOptions, - const TLSTicketKeySeeds* ticketSeeds, - const folly::SocketAddress& vipAddress, - const std::shared_ptr& externalCache) { - - unsigned numCerts = 0; - std::string commonName; - std::string lastCertPath; - std::unique_ptr> subjectAltName; - auto sslCtx = std::make_shared(ctxConfig.sslVersion); - for (const auto& cert : ctxConfig.certificates) { - try { - sslCtx->loadCertificate(cert.certPath.c_str()); - } catch (const std::exception& ex) { - // The exception isn't very useful without the certificate path name, - // so throw a new exception that includes the path to the certificate. - string msg = folly::to("error loading SSL certificate ", - cert.certPath, ": ", - folly::exceptionStr(ex)); - LOG(ERROR) << msg; - throw std::runtime_error(msg); - } - - // Verify that the Common Name and (if present) Subject Alternative Names - // are the same for all the certs specified for the SSL context. - numCerts++; - X509* x509 = getX509(sslCtx->getSSLCtx()); - auto guard = folly::makeGuard([x509] { X509_free(x509); }); - auto cn = SSLUtil::getCommonName(x509); - if (!cn) { - throw std::runtime_error(folly::to("Cannot get CN for X509 ", - cert.certPath)); - } - auto altName = SSLUtil::getSubjectAltName(x509); - VLOG(2) << "cert " << cert.certPath << " CN: " << *cn; - if (altName) { - altName->sort(); - VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName); - } else { - VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}"; - } - if (numCerts == 1) { - commonName = *cn; - subjectAltName = std::move(altName); - } else { - if (commonName != *cn) { - throw std::runtime_error(folly::to("X509 ", cert.certPath, - " does not have same CN as ", - lastCertPath)); - } - if (altName == nullptr) { - if (subjectAltName != nullptr) { - throw std::runtime_error(folly::to("X509 ", cert.certPath, - " does not have same SAN as ", - lastCertPath)); - } - } else { - if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) { - throw std::runtime_error(folly::to("X509 ", cert.certPath, - " does not have same SAN as ", - lastCertPath)); - } - } - } - lastCertPath = cert.certPath; - - // TODO t4438250 - Add ECDSA support to the crypto_ssl offload server - // so we can avoid storing the ECDSA private key in the - // address space of the Internet-facing process. For - // now, if cert name includes "-EC" to denote elliptic - // curve, we load its private key even if the server as - // a whole has been configured for async crypto. - if (ctxConfig.isLocalPrivateKey || - (cert.certPath.find("-EC") != std::string::npos)) { - // The private key lives in the same process - - // This needs to be called before loadPrivateKey(). - if (!cert.passwordPath.empty()) { - auto sslPassword = std::make_shared(cert.passwordPath); - sslCtx->passwordCollector(sslPassword); - } - - try { - sslCtx->loadPrivateKey(cert.keyPath.c_str()); - } catch (const std::exception& ex) { - // Throw an error that includes the key path, so the user can tell - // which key had a problem. - string msg = folly::to("error loading private SSL key ", - cert.keyPath, ": ", - folly::exceptionStr(ex)); - LOG(ERROR) << msg; - throw std::runtime_error(msg); - } - } - } - if (!ctxConfig.isLocalPrivateKey) { - enableAsyncCrypto(sslCtx); - } - - // Let the server pick the highest performing cipher from among the client's - // choices. - // - // Let's use a unique private key for all DH key exchanges. - // - // Because some old implementations choke on empty fragments, most SSL - // applications disable them (it's part of SSL_OP_ALL). This - // will improve performance and decrease write buffer fragmentation. - sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE | - SSL_OP_SINGLE_DH_USE | - SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); - - // Configure SSL ciphers list - if (!ctxConfig.tls11Ciphers.empty()) { - // FIXME: create a dummy SSL_CTX for cipher testing purpose? It can - // remove the ordering dependency - - // Test to see if the specified TLS1.1 ciphers are valid. Note that - // these will be overwritten by the ciphers() call below. - sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers); - } - - // Important that we do this *after* checking the TLS1.1 ciphers above, - // since we test their validity by actually setting them. - sslCtx->ciphers(ctxConfig.sslCiphers); - - // Use a fix DH param - DH* dh = get_dh2048(); - SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh); - DH_free(dh); - - const string& curve = ctxConfig.eccCurveName; - if (!curve.empty()) { - set_key_from_curve(sslCtx->getSSLCtx(), curve); - } - - if (!ctxConfig.clientCAFile.empty()) { - try { - sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT); - sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str()); - sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str()); - } catch (const std::exception& ex) { - string msg = folly::to("error loading client CA", - ctxConfig.clientCAFile, ": ", - folly::exceptionStr(ex)); - LOG(ERROR) << msg; - throw std::runtime_error(msg); - } - } - - // - start - SSL session cache config - // the internal cache never does what we want (per-thread-per-vip). - // Disable it. SSLSessionCacheManager will set it appropriately. - SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF); - SSL_CTX_set_timeout(sslCtx->getSSLCtx(), - cacheOptions.sslCacheTimeout.count()); - std::unique_ptr sessionCacheManager; - if (ctxConfig.sessionCacheEnabled && - cacheOptions.maxSSLCacheSize > 0 && - cacheOptions.sslCacheFlushSize > 0) { - sessionCacheManager = - folly::make_unique( - cacheOptions.maxSSLCacheSize, - cacheOptions.sslCacheFlushSize, - sslCtx.get(), - vipAddress, - commonName, - eventBase_, - stats_, - externalCache); - } - // - end - SSL session cache config - - std::unique_ptr ticketManager = - createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_); - - // finalize sslCtx setup by the individual features supported by openssl - ctxSetupByOpensslFeature(sslCtx, ctxConfig); - - try { - insert(sslCtx, - std::move(sessionCacheManager), - std::move(ticketManager), - ctxConfig.isDefault); - } catch (const std::exception& ex) { - string msg = folly::to("Error adding certificate : ", - folly::exceptionStr(ex)); - LOG(ERROR) << msg; - throw std::runtime_error(msg); - } - -} - -#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK -SSLContext::ServerNameCallbackResult -SSLContextManager::serverNameCallback(SSL* ssl) { - shared_ptr ctx; - - const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); - if (!sn) { - VLOG(6) << "Server Name (tlsext_hostname) is missing"; - if (clientHelloTLSExtStats_) { - clientHelloTLSExtStats_->recordAbsentHostname(); - } - return SSLContext::SERVER_NAME_NOT_FOUND; - } - size_t snLen = strlen(sn); - VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' "; - - // FIXME: This code breaks the abstraction. Suggestion? - AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl); - CHECK(sslSocket); - - DNString dnstr(sn, snLen); - - uint32_t count = 0; - do { - // Try exact match first - ctx = getSSLCtx(dnstr); - if (ctx) { - sslSocket->switchServerSSLContext(ctx); - if (clientHelloTLSExtStats_) { - clientHelloTLSExtStats_->recordMatch(); - } - return SSLContext::SERVER_NAME_FOUND; - } - - ctx = getSSLCtxBySuffix(dnstr); - if (ctx) { - sslSocket->switchServerSSLContext(ctx); - if (clientHelloTLSExtStats_) { - clientHelloTLSExtStats_->recordMatch(); - } - return SSLContext::SERVER_NAME_FOUND; - } - - // Give the noMatchFn one chance to add the correct cert - } - while (count++ == 0 && noMatchFn_ && noMatchFn_(sn)); - - VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn); - - if (clientHelloTLSExtStats_) { - clientHelloTLSExtStats_->recordNotMatch(); - } - return SSLContext::SERVER_NAME_NOT_FOUND; -} -#endif - -// Consolidate all SSL_CTX setup which depends on openssl version/feature -void -SSLContextManager::ctxSetupByOpensslFeature( - shared_ptr sslCtx, - const SSLContextConfig& ctxConfig) { - // Disable compression - profiling shows this to be very expensive in - // terms of CPU and memory consumption. - // -#ifdef SSL_OP_NO_COMPRESSION - sslCtx->setOptions(SSL_OP_NO_COMPRESSION); -#endif - - // Enable early release of SSL buffers to reduce the memory footprint -#ifdef SSL_MODE_RELEASE_BUFFERS - sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS; -#endif -#ifdef SSL_MODE_EARLY_RELEASE_BBIO - sslCtx->getSSLCtx()->mode |= SSL_MODE_EARLY_RELEASE_BBIO; -#endif - - // This number should (probably) correspond to HTTPSession::kMaxReadSize - // For now, this number must also be large enough to accommodate our - // largest certificate, because some older clients (IE6/7) require the - // cert to be in a single fragment. -#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT - SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000); -#endif - - // Specify cipher(s) to be used for TLS1.1 client - if (!ctxConfig.tls11Ciphers.empty()) { -#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK - // Specified TLS1.1 ciphers are valid - sslCtx->addClientHelloCallback( - std::bind( - &SSLContext::switchCiphersIfTLS11, - sslCtx.get(), - std::placeholders::_1, - ctxConfig.tls11Ciphers - ) - ); -#else - OPENSSL_MISSING_FEATURE(SNI); -#endif - } - - // NPN (Next Protocol Negotiation) - if (!ctxConfig.nextProtocols.empty()) { -#ifdef OPENSSL_NPN_NEGOTIATED - sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols); -#else - OPENSSL_MISSING_FEATURE(NPN); -#endif - } - - // SNI -#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK - noMatchFn_ = ctxConfig.sniNoMatchFn; - if (ctxConfig.isDefault) { - if (defaultCtx_) { - throw std::runtime_error(">1 X509 is set as default"); - } - - defaultCtx_ = sslCtx; - defaultCtx_->setServerNameCallback( - std::bind(&SSLContextManager::serverNameCallback, this, - std::placeholders::_1)); - } -#else - if (ctxs_.size() > 1) { - OPENSSL_MISSING_FEATURE(SNI); - } -#endif -} - -void -SSLContextManager::insert(shared_ptr sslCtx, - std::unique_ptr smanager, - std::unique_ptr tmanager, - bool defaultFallback) { - X509* x509 = getX509(sslCtx->getSSLCtx()); - auto guard = folly::makeGuard([x509] { X509_free(x509); }); - auto cn = SSLUtil::getCommonName(x509); - if (!cn) { - throw std::runtime_error("Cannot get CN"); - } - - /** - * Some notes from RFC 2818. Only for future quick references in case of bugs - * - * RFC 2818 section 3.1: - * "...... - * If a subjectAltName extension of type dNSName is present, that MUST - * be used as the identity. Otherwise, the (most specific) Common Name - * field in the Subject field of the certificate MUST be used. Although - * the use of the Common Name is existing practice, it is deprecated and - * Certification Authorities are encouraged to use the dNSName instead. - * ...... - * In some cases, the URI is specified as an IP address rather than a - * hostname. In this case, the iPAddress subjectAltName must be present - * in the certificate and must exactly match the IP in the URI. - * ......" - */ - - // Not sure if we ever get this kind of X509... - // If we do, assume '*' is always in the CN and ignore all subject alternative - // names. - if (cn->length() == 1 && (*cn)[0] == '*') { - if (!defaultFallback) { - throw std::runtime_error("STAR X509 is not the default"); - } - ctxs_.emplace_back(sslCtx); - sessionCacheManagers_.emplace_back(std::move(smanager)); - ticketManagers_.emplace_back(std::move(tmanager)); - return; - } - - // Insert by CN - insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx); - - // Insert by subject alternative name(s) - auto altNames = SSLUtil::getSubjectAltName(x509); - if (altNames) { - for (auto& name : *altNames) { - insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx); - } - } - - ctxs_.emplace_back(sslCtx); - sessionCacheManagers_.emplace_back(std::move(smanager)); - ticketManagers_.emplace_back(std::move(tmanager)); -} - -void -SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len, - shared_ptr sslCtx) { - try { - insertSSLCtxByDomainNameImpl(dn, len, sslCtx); - } catch (const std::runtime_error& ex) { - if (strict_) { - throw ex; - } else { - LOG(ERROR) << ex.what() << " DN=" << dn; - } - } -} -void -SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len, - shared_ptr sslCtx) -{ - VLOG(4) << - folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for " - "SNI search", dn); - - // Only support wildcard domains which are prefixed exactly by "*." . - // "*" appearing at other locations is not accepted. - - if (len > 2 && dn[0] == '*') { - if (dn[1] == '.') { - // skip the first '*' - dn++; - len--; - } else { - throw std::runtime_error( - "Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" " - "(only allow character \".\" after \"*\""); - } - } - - if (len == 1 && *dn == '.') { - throw std::runtime_error("X509 has only '.' in the CN or subject alternative name " - "(after removing any preceding '*')"); - } - - if (strchr(dn, '*')) { - throw std::runtime_error("X509 has '*' in the the CN or subject alternative name " - "(after removing any preceding '*')"); - } - - DNString dnstr(dn, len); - const auto v = dnMap_.find(dnstr); - if (v == dnMap_.end()) { - dnMap_.emplace(dnstr, sslCtx); - } else if (v->second == sslCtx) { - VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509." - " Ignore the later name."; - } else { - throw std::runtime_error("Duplicate CN or subject alternative name found: \"" + - std::string(dnstr.c_str()) + "\""); - } -} - -shared_ptr -SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const -{ - size_t dot; - - if ((dot = dnstr.find_first_of(".")) != DNString::npos) { - DNString suffixDNStr(dnstr, dot); - const auto v = dnMap_.find(suffixDNStr); - if (v != dnMap_.end()) { - VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"", - dnstr.c_str(), suffixDNStr.c_str()); - return v->second; - } - } - - VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match", - dnstr.c_str()); - return shared_ptr(); -} - -shared_ptr -SSLContextManager::getSSLCtx(const DNString& dnstr) const -{ - const auto v = dnMap_.find(dnstr); - if (v == dnMap_.end()) { - VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match", - dnstr.c_str()); - return shared_ptr(); - } else { - VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str()); - return v->second; - } -} - -shared_ptr -SSLContextManager::getDefaultSSLCtx() const { - return defaultCtx_; -} - -void -SSLContextManager::reloadTLSTicketKeys( - const std::vector& oldSeeds, - const std::vector& currentSeeds, - const std::vector& newSeeds) { -#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB - for (auto& tmgr: ticketManagers_) { - tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds); - } -#endif -} - -} // namespace diff --git a/folly/experimental/wangle/ssl/SSLContextManager.h b/folly/experimental/wangle/ssl/SSLContextManager.h deleted file mode 100644 index 26506493..00000000 --- a/folly/experimental/wangle/ssl/SSLContextManager.h +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace folly { - -class SocketAddress; -class SSLContext; -class ClientHelloExtStats; -class SSLCacheOptions; -class SSLStats; -class TLSTicketKeyManager; -class TLSTicketKeySeeds; - -class SSLContextManager { - public: - - explicit SSLContextManager(EventBase* eventBase, - const std::string& vipName, bool strict, - SSLStats* stats); - virtual ~SSLContextManager(); - - /** - * Add a new X509 to SSLContextManager. The details of a X509 - * is passed as a SSLContextConfig object. - * - * @param ctxConfig Details of a X509, its private key, password, etc. - * @param cacheOptions Options for how to do session caching. - * @param ticketSeeds If non-null, the initial ticket key seeds to use. - * @param vipAddress Which VIP are the X509(s) used for? It is only for - * for user friendly log message - * @param externalCache Optional external provider for the session cache; - * may be null - */ - void addSSLContextConfig( - const SSLContextConfig& ctxConfig, - const SSLCacheOptions& cacheOptions, - const TLSTicketKeySeeds* ticketSeeds, - const folly::SocketAddress& vipAddress, - const std::shared_ptr &externalCache); - - /** - * Get the default SSL_CTX for a VIP - */ - std::shared_ptr - getDefaultSSLCtx() const; - - /** - * Search by the _one_ level up subdomain - */ - std::shared_ptr - getSSLCtxBySuffix(const DNString& dnstr) const; - - /** - * Search by the full-string domain name - */ - std::shared_ptr - getSSLCtx(const DNString& dnstr) const; - - /** - * Insert a SSLContext by domain name. - */ - void insertSSLCtxByDomainName( - const char* dn, - size_t len, - std::shared_ptr sslCtx); - - void insertSSLCtxByDomainNameImpl( - const char* dn, - size_t len, - std::shared_ptr sslCtx); - - void reloadTLSTicketKeys(const std::vector& oldSeeds, - const std::vector& currentSeeds, - const std::vector& newSeeds); - - /** - * SSLContextManager only collects SNI stats now - */ - - void setClientHelloExtStats(ClientHelloExtStats* stats) { - clientHelloTLSExtStats_ = stats; - } - - protected: - virtual void enableAsyncCrypto( - const std::shared_ptr& sslCtx) { - LOG(FATAL) << "Unsupported in base SSLContextManager"; - } - SSLStats* stats_{nullptr}; - - private: - SSLContextManager(const SSLContextManager&) = delete; - - void ctxSetupByOpensslFeature( - std::shared_ptr sslCtx, - const SSLContextConfig& ctxConfig); - - /** - * Callback function from openssl to find the right X509 to - * use during SSL handshake - */ -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \ - !defined(OPENSSL_NO_TLSEXT) && \ - defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB) -# define PROXYGEN_HAVE_SERVERNAMECALLBACK - SSLContext::ServerNameCallbackResult - serverNameCallback(SSL* ssl); -#endif - - /** - * The following functions help to maintain the data structure for - * domain name matching in SNI. Some notes: - * - * 1. It is a best match. - * - * 2. It allows wildcard CN and wildcard subject alternative name in a X509. - * The wildcard name must be _prefixed_ by '*.'. It errors out whenever - * it sees '*' in any other locations. - * - * 3. It uses one std::unordered_map object to - * do this. For wildcard name like "*.facebook.com", ".facebook.com" - * is used as the key. - * - * 4. After getting tlsext_hostname from the client hello message, it - * will do a full string search first and then try one level up to - * match any wildcard name (if any) in the X509. - * [Note, browser also only looks one level up when matching the requesting - * domain name with the wildcard name in the server X509]. - */ - - void insert( - std::shared_ptr sslCtx, - std::unique_ptr cmanager, - std::unique_ptr tManager, - bool defaultFallback); - - /** - * Container to own the SSLContext, SSLSessionCacheManager and - * TLSTicketKeyManager. - */ - std::vector> ctxs_; - std::vector> - sessionCacheManagers_; - std::vector> ticketManagers_; - - std::shared_ptr defaultCtx_; - - /** - * Container to store the (DomainName -> SSL_CTX) mapping - */ - std::unordered_map< - DNString, - std::shared_ptr, - DNStringHash> dnMap_; - - EventBase* eventBase_; - ClientHelloExtStats* clientHelloTLSExtStats_{nullptr}; - SSLContextConfig::SNINoMatchFn noMatchFn_; - bool strict_{true}; -}; - -} // namespace diff --git a/folly/experimental/wangle/ssl/SSLSessionCacheManager.cpp b/folly/experimental/wangle/ssl/SSLSessionCacheManager.cpp deleted file mode 100644 index e9111b6b..00000000 --- a/folly/experimental/wangle/ssl/SSLSessionCacheManager.cpp +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include -#include - -#include - -#ifndef NO_LIB_GFLAGS -#include -#endif - -using std::string; -using std::shared_ptr; - -namespace { - -const uint32_t NUM_CACHE_BUCKETS = 16; - -// We use the default ID generator which fills the maximum ID length -// for the protocol. 16 bytes for SSLv2 or 32 for SSLv3+ -const int MIN_SESSION_ID_LENGTH = 16; - -} - -#ifndef NO_LIB_GFLAGS -DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache"); -#else -const bool FLAGS_dcache_unit_test = false; -#endif - -namespace folly { - - -int SSLSessionCacheManager::sExDataIndex_ = -1; -shared_ptr SSLSessionCacheManager::sCache_; -std::mutex SSLSessionCacheManager::sCacheLock_; - -LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize, - uint32_t cacheCullSize) - : sessionCache(maxCacheSize, cacheCullSize) { - sessionCache.setPruneHook(std::bind( - &LocalSSLSessionCache::pruneSessionCallback, - this, std::placeholders::_1, - std::placeholders::_2)); -} - -void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId, - SSL_SESSION* session) { - VLOG(4) << "Free SSL session from local cache; id=" - << SSLUtil::hexlify(sessionId); - SSL_SESSION_free(session); - ++removedSessions_; -} - - -// SSLSessionCacheManager implementation - -SSLSessionCacheManager::SSLSessionCacheManager( - uint32_t maxCacheSize, - uint32_t cacheCullSize, - SSLContext* ctx, - const folly::SocketAddress& sockaddr, - const string& context, - EventBase* eventBase, - SSLStats* stats, - const std::shared_ptr& externalCache): - ctx_(ctx), - stats_(stats), - externalCache_(externalCache) { - - SSL_CTX* sslCtx = ctx->getSSLCtx(); - - SSLUtil::getSSLCtxExIndex(&sExDataIndex_); - - SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this); - SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback); - SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback); - SSL_CTX_sess_set_remove_cb(sslCtx, - SSLSessionCacheManager::removeSessionCallback); - if (!FLAGS_dcache_unit_test && !context.empty()) { - // Use the passed in context - SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(), - std::min((int)context.length(), - SSL_MAX_SSL_SESSION_ID_LENGTH)); - } - - SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL - | SSL_SESS_CACHE_SERVER); - - localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize, - cacheCullSize); - - VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context; -} - -SSLSessionCacheManager::~SSLSessionCacheManager() { -} - -void SSLSessionCacheManager::shutdown() { - std::lock_guard g(sCacheLock_); - sCache_.reset(); -} - -shared_ptr SSLSessionCacheManager::getLocalCache( - uint32_t maxCacheSize, - uint32_t cacheCullSize) { - - std::lock_guard g(sCacheLock_); - if (!sCache_) { - sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS, - maxCacheSize, - cacheCullSize)); - } - return sCache_; -} - -int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) { - SSLSessionCacheManager* manager = nullptr; - SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); - manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); - - if (manager == nullptr) { - LOG(FATAL) << "Null SSLSessionCacheManager in callback"; - return -1; - } - return manager->newSession(ssl, session); -} - - -int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) { - string sessionId((char*)session->session_id, session->session_id_length); - VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId); - - if (stats_) { - stats_->recordSSLSession(true /* new session */, false, false); - } - - localCache_->storeSession(sessionId, session, stats_); - - if (externalCache_) { - VLOG(4) << "New SSL session: send session to external cache; id=" << - SSLUtil::hexlify(sessionId); - storeCacheRecord(sessionId, session); - } - - return 1; -} - -void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx, - SSL_SESSION* session) { - SSLSessionCacheManager* manager = nullptr; - manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); - - if (manager == nullptr) { - LOG(FATAL) << "Null SSLSessionCacheManager in callback"; - return; - } - return manager->removeSession(ctx, session); -} - -void SSLSessionCacheManager::removeSession(SSL_CTX* ctx, - SSL_SESSION* session) { - string sessionId((char*)session->session_id, session->session_id_length); - - // This hook is only called from SSL when the internal session cache needs to - // flush sessions. Since we run with the internal cache disabled, this should - // never be called - VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId); - - localCache_->removeSession(sessionId); - - if (stats_) { - stats_->recordSSLSessionRemove(); - } -} - -SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl, - unsigned char* sess_id, - int id_len, - int* copyflag) { - SSLSessionCacheManager* manager = nullptr; - SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); - manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); - - if (manager == nullptr) { - LOG(FATAL) << "Null SSLSessionCacheManager in callback"; - return nullptr; - } - return manager->getSession(ssl, sess_id, id_len, copyflag); -} - -SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl, - unsigned char* session_id, - int id_len, - int* copyflag) { - VLOG(7) << "SSL get session callback"; - SSL_SESSION* session = nullptr; - bool foreign = false; - char const* missReason = nullptr; - - if (id_len < MIN_SESSION_ID_LENGTH) { - // We didn't generate this session so it's going to be a miss. - // This doesn't get logged or counted in the stats. - return nullptr; - } - string sessionId((char*)session_id, id_len); - - AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl); - - assert(sslSocket != nullptr); - - // look it up in the local cache first - session = localCache_->lookupSession(sessionId); -#ifdef SSL_SESSION_CB_WOULD_BLOCK - if (session == nullptr && externalCache_) { - // external cache might have the session - foreign = true; - if (!SSL_want_sess_cache_lookup(ssl)) { - missReason = "reason: No async cache support;"; - } else { - PendingLookupMap::iterator pit = pendingLookups_.find(sessionId); - if (pit == pendingLookups_.end()) { - auto result = pendingLookups_.emplace(sessionId, PendingLookup()); - // initiate fetch - VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" << - sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId); - if (lookupCacheRecord(sessionId, sslSocket)) { - // response is pending - *copyflag = SSL_SESSION_CB_WOULD_BLOCK; - return nullptr; - } else { - missReason = "reason: failed to send lookup request;"; - pendingLookups_.erase(result.first); - } - } else { - // A lookup was already initiated from this thread - if (pit->second.request_in_progress) { - // Someone else initiated the request, attach - VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; " - "fd=" << sslSocket->getFd() << " id=" << - SSLUtil::hexlify(sessionId); - std::unique_ptr dg( - new DelayedDestruction::DestructorGuard(sslSocket)); - pit->second.waiters.push_back( - std::make_pair(sslSocket, std::move(dg))); - *copyflag = SSL_SESSION_CB_WOULD_BLOCK; - return nullptr; - } - // request is complete - session = pit->second.session; // nullptr if our friend didn't have it - if (session != nullptr) { - CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); - } - } - } - } -#endif - - bool hit = (session != nullptr); - if (stats_) { - stats_->recordSSLSession(false, hit, foreign); - } - if (hit) { - sslSocket->setSessionIDResumed(true); - } - - VLOG(4) << "Get SSL session [" << - ((hit) ? "Hit" : "Miss") << "]: " << - ((foreign) ? "external" : "local") << " cache; " << - ((missReason != nullptr) ? missReason : "") << "fd=" << - sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId); - - // We already bumped the refcount - *copyflag = 0; - - return session; -} - -bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId, - SSL_SESSION* session) { - std::string sessionString; - uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr); - sessionString.resize(sessionLen); - uint8_t* cp = (uint8_t *)sessionString.data(); - i2d_SSL_SESSION(session, &cp); - size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx()); - return externalCache_->setAsync(sessionId, sessionString, - std::chrono::seconds(expiration)); -} - -bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId, - AsyncSSLSocket* sslSocket) { - auto cacheCtx = new SSLCacheProvider::CacheContext(); - cacheCtx->sessionId = sessionId; - cacheCtx->session = nullptr; - cacheCtx->sslSocket = sslSocket; - cacheCtx->guard.reset( - new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket)); - cacheCtx->manager = this; - bool res = externalCache_->getAsync(sessionId, cacheCtx); - if (!res) { - delete cacheCtx; - } - return res; -} - -void SSLSessionCacheManager::restartSSLAccept( - const SSLCacheProvider::CacheContext* cacheCtx) { - PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId); - CHECK(pit != pendingLookups_.end()); - pit->second.request_in_progress = false; - pit->second.session = cacheCtx->session; - VLOG(7) << "Restart SSL accept"; - cacheCtx->sslSocket->restartSSLAccept(); - for (const auto& attachedLookup: pit->second.waiters) { - // Wake up anyone else who was waiting for this session - VLOG(4) << "Restart SSL accept (waiters) for fd=" << - attachedLookup.first->getFd(); - attachedLookup.first->restartSSLAccept(); - } - pendingLookups_.erase(pit); -} - -void SSLSessionCacheManager::onGetSuccess( - SSLCacheProvider::CacheContext* cacheCtx, - const std::string& value) { - const uint8_t* cp = (uint8_t*)value.data(); - cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length()); - restartSSLAccept(cacheCtx); - - /* Insert in the LRU after restarting all clients. The stats logic - * in getSession would treat this as a local hit otherwise. - */ - localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_); - delete cacheCtx; -} - -void SSLSessionCacheManager::onGetFailure( - SSLCacheProvider::CacheContext* cacheCtx) { - restartSSLAccept(cacheCtx); - delete cacheCtx; -} - -} // namespace diff --git a/folly/experimental/wangle/ssl/SSLSessionCacheManager.h b/folly/experimental/wangle/ssl/SSLSessionCacheManager.h deleted file mode 100644 index f9c9e5de..00000000 --- a/folly/experimental/wangle/ssl/SSLSessionCacheManager.h +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include - -#include -#include -#include - -namespace folly { - -class SSLStats; - -/** - * Basic SSL session cache map: Maps session id -> session - */ -typedef folly::EvictingCacheMap SSLSessionCacheMap; - -/** - * Holds an SSLSessionCacheMap and associated lock - */ -class LocalSSLSessionCache: private boost::noncopyable { - public: - LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize); - - ~LocalSSLSessionCache() { - std::lock_guard g(lock); - // EvictingCacheMap dtor doesn't free values - sessionCache.clear(); - } - - SSLSessionCacheMap sessionCache; - std::mutex lock; - uint32_t removedSessions_{0}; - - private: - - void pruneSessionCallback(const std::string& sessionId, - SSL_SESSION* session); -}; - -/** - * A sharded LRU for SSL sessions. The sharding is inteneded to reduce - * contention for the LRU locks. Assuming uniform distribution, two workers - * will contend for the same lock with probability 1 / n_buckets^2. - */ -class ShardedLocalSSLSessionCache : private boost::noncopyable { - public: - ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize, - uint32_t cacheCullSize) { - CHECK(n_buckets > 0); - maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets); - cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets); - if (maxCacheSize == 0) { - maxCacheSize = 1; - } - if (cacheCullSize == 0) { - cacheCullSize = 1; - } - for (uint32_t i = 0; i < n_buckets; i++) { - caches_.push_back( - std::unique_ptr( - new LocalSSLSessionCache(maxCacheSize, cacheCullSize))); - } - } - - SSL_SESSION* lookupSession(const std::string& sessionId) { - size_t bucket = hash(sessionId); - SSL_SESSION* session = nullptr; - std::lock_guard g(caches_[bucket]->lock); - - auto itr = caches_[bucket]->sessionCache.find(sessionId); - if (itr != caches_[bucket]->sessionCache.end()) { - session = itr->second; - } - - if (session) { - CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); - } - return session; - } - - void storeSession(const std::string& sessionId, SSL_SESSION* session, - SSLStats* stats) { - size_t bucket = hash(sessionId); - SSL_SESSION* oldSession = nullptr; - std::lock_guard g(caches_[bucket]->lock); - - auto itr = caches_[bucket]->sessionCache.find(sessionId); - if (itr != caches_[bucket]->sessionCache.end()) { - oldSession = itr->second; - } - - if (oldSession) { - // LRUCacheMap doesn't free on overwrite, so 2x the work for us - // This can happen in race conditions - SSL_SESSION_free(oldSession); - } - caches_[bucket]->removedSessions_ = 0; - caches_[bucket]->sessionCache.set(sessionId, session, true); - if (stats) { - stats->recordSSLSessionFree(caches_[bucket]->removedSessions_); - } - } - - void removeSession(const std::string& sessionId) { - size_t bucket = hash(sessionId); - std::lock_guard g(caches_[bucket]->lock); - caches_[bucket]->sessionCache.erase(sessionId); - } - - private: - - /* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */ - size_t hash(const std::string& key) { - CHECK(key.length() >= 2); - return (key[0] << 8 | key[1]) % caches_.size(); - } - - std::vector< std::unique_ptr > caches_; -}; - -/* A socket/DestructorGuard pair */ -typedef std::pair> - AttachedLookup; - -/** - * PendingLookup structure - * - * Keeps track of clients waiting for an SSL session to be retrieved from - * the external cache provider. - */ -struct PendingLookup { - bool request_in_progress; - SSL_SESSION* session; - std::list waiters; - - PendingLookup() { - request_in_progress = true; - session = nullptr; - } -}; - -/* Maps SSL session id to a PendingLookup structure */ -typedef std::map PendingLookupMap; - -/** - * SSLSessionCacheManager handles all stateful session caching. There is an - * instance of this object per SSL VIP per thread, with a 1:1 correlation with - * SSL_CTX. The cache can work locally or in concert with an external cache - * to share sessions across instances. - * - * There is a single in memory session cache shared by all VIPs. The cache is - * split into N buckets (currently 16) with a separate lock per bucket. The - * VIP ID is hashed and stored as part of the session to handle the - * (very unlikely) case of session ID collision. - * - * When a new SSL session is created, it is added to the LRU cache and - * sent to the external cache to be stored. The external cache - * expiration is equal to the SSL session's expiration. - * - * When a resume request is received, SSLSessionCacheManager first looks in the - * local LRU cache for the VIP. If there is a miss there, an asynchronous - * request for this session is dispatched to the external cache. When the - * external cache query returns, the LRU cache is updated if the session was - * found, and the SSL_accept call is resumed. - * - * If additional resume requests for the same session ID arrive in the same - * thread while the request is pending, the 2nd - Nth callers attach to the - * original external cache requests and are resumed when it comes back. No - * attempt is made to coalesce external cache requests for the same session - * ID in different worker threads. Previous work did this, but the - * complexity was deemed to outweigh the potential savings. - * - */ -class SSLSessionCacheManager : private boost::noncopyable { - public: - /** - * Constructor. SSL session related callbacks will be set on the underlying - * SSL_CTX. vipId is assumed to a unique string identifying the VIP and must - * be the same on all servers that wish to share sessions via the same - * external cache. - */ - SSLSessionCacheManager( - uint32_t maxCacheSize, - uint32_t cacheCullSize, - SSLContext* ctx, - const folly::SocketAddress& sockaddr, - const std::string& context, - EventBase* eventBase, - SSLStats* stats, - const std::shared_ptr& externalCache); - - virtual ~SSLSessionCacheManager(); - - /** - * Call this on shutdown to release the global instance of the - * ShardedLocalSSLSessionCache. - */ - static void shutdown(); - - /** - * Callback for ExternalCache to call when an async get succeeds - * @param context The context that was passed to the async get request - * @param value Serialized session - */ - void onGetSuccess(SSLCacheProvider::CacheContext* context, - const std::string& value); - - /** - * Callback for ExternalCache to call when an async get fails, either - * because the requested session is not in the external cache or because - * of an error. - * @param context The context that was passed to the async get request - */ - void onGetFailure(SSLCacheProvider::CacheContext* context); - - private: - - SSLContext* ctx_; - std::shared_ptr localCache_; - PendingLookupMap pendingLookups_; - SSLStats* stats_{nullptr}; - std::shared_ptr externalCache_; - - /** - * Invoked by openssl when a new SSL session is created - */ - int newSession(SSL* ssl, SSL_SESSION* session); - - /** - * Invoked by openssl when an SSL session is ejected from its internal cache. - * This can't be invoked in the current implementation because SSL's internal - * caching is disabled. - */ - void removeSession(SSL_CTX* ctx, SSL_SESSION* session); - - /** - * Invoked by openssl when a client requests a stateful session resumption. - * Triggers a lookup in our local cache and potentially an asynchronous - * request to an external cache. - */ - SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id, - int id_len, int* copyflag); - - /** - * Store a new session record in the external cache - */ - bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session); - - /** - * Lookup a session in the external cache for the specified SSL socket. - */ - bool lookupCacheRecord(const std::string& sessionId, - AsyncSSLSocket* sslSock); - - /** - * Restart all clients waiting for the answer to an external cache query - */ - void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx); - - /** - * Get or create the LRU cache for the given VIP ID - */ - static std::shared_ptr getLocalCache( - uint32_t maxCacheSize, uint32_t cacheCullSize); - - /** - * static functions registered as callbacks to openssl via - * SSL_CTX_sess_set_new/get/remove_cb - */ - static int newSessionCallback(SSL* ssl, SSL_SESSION* session); - static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session); - static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id, - int id_len, int* copyflag); - - static int32_t sExDataIndex_; - static std::shared_ptr sCache_; - static std::mutex sCacheLock_; -}; - -} diff --git a/folly/experimental/wangle/ssl/SSLStats.h b/folly/experimental/wangle/ssl/SSLStats.h deleted file mode 100644 index 761a8434..00000000 --- a/folly/experimental/wangle/ssl/SSLStats.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -namespace folly { - -class SSLStats { - public: - virtual ~SSLStats() noexcept {} - - // downstream - virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0; - virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0; - virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign) - noexcept = 0; - virtual void recordSSLSessionRemove() noexcept = 0; - virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0; - virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0; - virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0; - virtual void recordClientRenegotiation() noexcept = 0; - - // upstream - virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0; - virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0; - virtual void recordCryptoSSLExternalAttempt() noexcept = 0; - virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0; - virtual void recordCryptoSSLExternalApplicationException() noexcept = 0; - virtual void recordCryptoSSLExternalSuccess() noexcept = 0; - virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0; - virtual void recordCryptoSSLLocalAttempt() noexcept = 0; - virtual void recordCryptoSSLLocalSuccess() noexcept = 0; - -}; - -} diff --git a/folly/experimental/wangle/ssl/SSLUtil.cpp b/folly/experimental/wangle/ssl/SSLUtil.cpp deleted file mode 100644 index 5557d1e4..00000000 --- a/folly/experimental/wangle/ssl/SSLUtil.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include - -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL -#define OPENSSL_GE_101 1 -#include -#include -#else -#undef OPENSSL_GE_101 -#endif - -namespace folly { - -std::mutex SSLUtil::sIndexLock_; - -std::unique_ptr SSLUtil::getCommonName(const X509* cert) { - X509_NAME* subject = X509_get_subject_name((X509*)cert); - if (!subject) { - return nullptr; - } - char cn[ub_common_name + 1]; - int res = X509_NAME_get_text_by_NID(subject, NID_commonName, - cn, ub_common_name); - if (res <= 0) { - return nullptr; - } else { - cn[ub_common_name] = '\0'; - return folly::make_unique(cn); - } -} - -std::unique_ptr> SSLUtil::getSubjectAltName( - const X509* cert) { -#ifdef OPENSSL_GE_101 - auto nameList = folly::make_unique>(); - GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i( - (X509*)cert, NID_subject_alt_name, nullptr, nullptr); - if (names) { - auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); }); - size_t count = sk_GENERAL_NAME_num(names); - CHECK(count < std::numeric_limits::max()); - for (int i = 0; i < (int)count; ++i) { - GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i); - if (generalName->type == GEN_DNS) { - ASN1_STRING* s = generalName->d.dNSName; - const char* name = (const char*)ASN1_STRING_data(s); - // I can't find any docs on what a negative return value here - // would mean, so I'm going to ignore it. - auto len = ASN1_STRING_length(s); - DCHECK(len >= 0); - if (size_t(len) != strlen(name)) { - // Null byte(s) in the name; return an error rather than depending on - // the caller to safely handle this case. - return nullptr; - } - nameList->emplace_back(name); - } - } - } - return nameList; -#else - return nullptr; -#endif -} - -} diff --git a/folly/experimental/wangle/ssl/SSLUtil.h b/folly/experimental/wangle/ssl/SSLUtil.h deleted file mode 100644 index 20a17a95..00000000 --- a/folly/experimental/wangle/ssl/SSLUtil.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include -#include - -namespace folly { - -/** - * SSL session establish/resume status - * - * changing these values will break logging pipelines - */ -enum class SSLResumeEnum : uint8_t { - HANDSHAKE = 0, - RESUME_SESSION_ID = 1, - RESUME_TICKET = 3, - NA = 2 -}; - -enum class SSLErrorEnum { - NO_ERROR, - TIMEOUT, - DROPPED -}; - -class SSLUtil { - private: - static std::mutex sIndexLock_; - - public: - /** - * Ensures only one caller will allocate an ex_data index for a given static - * or global. - */ - static void getSSLCtxExIndex(int* pindex) { - std::lock_guard g(sIndexLock_); - if (*pindex < 0) { - *pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - } - } - - static void getRSAExIndex(int* pindex) { - std::lock_guard g(sIndexLock_); - if (*pindex < 0) { - *pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - } - } - - static inline std::string hexlify(const std::string& binary) { - std::string hex; - folly::hexlify(binary, hex); - - return hex; - } - - static inline const std::string& hexlify(const std::string& binary, - std::string& hex) { - folly::hexlify(binary, hex); - - return hex; - } - - /** - * Return the SSL resume type for the given socket. - */ - static inline SSLResumeEnum getResumeState( - AsyncSSLSocket* sslSocket) { - return sslSocket->getSSLSessionReused() ? - (sslSocket->sessionIDResumed() ? - SSLResumeEnum::RESUME_SESSION_ID : - SSLResumeEnum::RESUME_TICKET) : - SSLResumeEnum::HANDSHAKE; - } - - /** - * Get the Common Name from an X.509 certificate - * @param cert certificate to inspect - * @return common name, or null if an error occurs - */ - static std::unique_ptr getCommonName(const X509* cert); - - /** - * Get the Subject Alternative Name value(s) from an X.509 certificate - * @param cert certificate to inspect - * @return set of zero or more alternative names, or null if - * an error occurs - */ - static std::unique_ptr> getSubjectAltName( - const X509* cert); -}; - -} diff --git a/folly/experimental/wangle/ssl/TLSTicketKeyManager.cpp b/folly/experimental/wangle/ssl/TLSTicketKeyManager.cpp deleted file mode 100644 index c02153a5..00000000 --- a/folly/experimental/wangle/ssl/TLSTicketKeyManager.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include - -#include -#include - -#include -#include -#include -#include -#include - -#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB -using std::string; - -namespace { - -const int kTLSTicketKeyNameLen = 4; -const int kTLSTicketKeySaltLen = 12; - -} - -namespace folly { - - -// TLSTicketKeyManager Implementation -int32_t TLSTicketKeyManager::sExDataIndex_ = -1; - -TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats) - : ctx_(ctx), - randState_(0), - stats_(stats) { - SSLUtil::getSSLCtxExIndex(&sExDataIndex_); - SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this); -} - -TLSTicketKeyManager::~TLSTicketKeyManager() { -} - -int -TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName, - unsigned char* iv, - EVP_CIPHER_CTX* cipherCtx, - HMAC_CTX* hmacCtx, int encrypt) { - TLSTicketKeyManager* manager = nullptr; - SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); - manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); - - if (manager == nullptr) { - LOG(FATAL) << "Null TLSTicketKeyManager in callback" ; - return -1; - } - return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt); -} - -int -TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName, - unsigned char* iv, - EVP_CIPHER_CTX* cipherCtx, - HMAC_CTX* hmacCtx, int encrypt) { - uint8_t salt[kTLSTicketKeySaltLen]; - uint8_t* saltptr = nullptr; - uint8_t output[SHA256_DIGEST_LENGTH]; - uint8_t* hmacKey = nullptr; - uint8_t* aesKey = nullptr; - TLSTicketKeySource* key = nullptr; - int result = 0; - - if (encrypt) { - key = findEncryptionKey(); - if (key == nullptr) { - // no keys available to encrypt - VLOG(2) << "No TLS ticket key found"; - return -1; - } - VLOG(4) << "Encrypting new ticket with key name=" << - SSLUtil::hexlify(key->keyName_); - - // Get a random salt and write out key name - RAND_pseudo_bytes(salt, (int)sizeof(salt)); - memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen); - memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen); - - // Create the unique keys by hashing with the salt - makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output); - // This relies on the fact that SHA256 has 32 bytes of output - // and that AES-128 keys are 16 bytes - hmacKey = output; - aesKey = output + SHA256_DIGEST_LENGTH / 2; - - // Initialize iv and cipher/mac CTX - RAND_pseudo_bytes(iv, AES_BLOCK_SIZE); - HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2, - EVP_sha256(), nullptr); - EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv); - - result = 1; - } else { - key = findDecryptionKey(keyName); - if (key == nullptr) { - // no ticket found for decryption - will issue a new ticket - if (VLOG_IS_ON(4)) { - string skeyName((char *)keyName, kTLSTicketKeyNameLen); - VLOG(4) << "Can't find ticket key with name=" << - SSLUtil::hexlify(skeyName)<< ", will generate new ticket"; - } - - result = 0; - } else { - VLOG(4) << "Decrypting ticket with key name=" << - SSLUtil::hexlify(key->keyName_); - - // Reconstruct the unique key via the salt - saltptr = keyName + kTLSTicketKeyNameLen; - makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output); - hmacKey = output; - aesKey = output + SHA256_DIGEST_LENGTH / 2; - - // Initialize cipher/mac CTX - HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2, - EVP_sha256(), nullptr); - EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv); - - result = 1; - } - } - // result records whether a ticket key was found to decrypt this ticket, - // not wether the session was re-used. - if (stats_) { - stats_->recordTLSTicket(encrypt, result); - } - - return result; -} - -bool -TLSTicketKeyManager::setTLSTicketKeySeeds( - const std::vector& oldSeeds, - const std::vector& currentSeeds, - const std::vector& newSeeds) { - - bool result = true; - - activeKeys_.clear(); - ticketKeys_.clear(); - ticketSeeds_.clear(); - const std::vector *seedList = &oldSeeds; - for (uint32_t i = 0; i < 3; i++) { - TLSTicketSeedType type = (TLSTicketSeedType)i; - if (type == SEED_CURRENT) { - seedList = ¤tSeeds; - } else if (type == SEED_NEW) { - seedList = &newSeeds; - } - - for (const auto& seedInput: *seedList) { - TLSTicketSeed* seed = insertSeed(seedInput, type); - if (seed == nullptr) { - result = false; - continue; - } - insertNewKey(seed, 1, nullptr); - } - } - if (!result) { - VLOG(2) << "One or more seeds failed to decode"; - } - - if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) { - LOG(WARNING) << "No keys configured, falling back to default"; - SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr); - return false; - } - SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), - TLSTicketKeyManager::callback); - - return true; -} - -string -TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n, - unsigned char* nameBuf) { - SHA256_CTX ctx; - - SHA256_Init(&ctx); - SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_)); - SHA256_Update(&ctx, &n, sizeof(n)); - SHA256_Final(nameBuf, &ctx); - return string((char *)nameBuf, kTLSTicketKeyNameLen); -} - -TLSTicketKeyManager::TLSTicketKeySource* -TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount, - TLSTicketKeySource* prevKey) { - unsigned char nameBuf[SHA256_DIGEST_LENGTH]; - std::unique_ptr newKey(new TLSTicketKeySource); - - // This function supports hash chaining but it is not currently used. - - if (prevKey != nullptr) { - hashNth(prevKey->keySource_, sizeof(prevKey->keySource_), - newKey->keySource_, 1); - } else { - // can't go backwards or the current is missing, start from the beginning - hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(), - newKey->keySource_, hashCount); - } - - newKey->hashCount_ = hashCount; - newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf); - newKey->type_ = seed->type_; - auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_, - std::move(newKey))); - - auto key = it.first->second.get(); - if (key->type_ == SEED_CURRENT) { - activeKeys_.push_back(key); - } - VLOG(4) << "Adding key for " << hashCount << " type=" << - (uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_); - - return key; -} - -void -TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len, - unsigned char* output, uint32_t n) { - assert(n > 0); - for (uint32_t i = 0; i < n; i++) { - SHA256(input, input_len, output); - input = output; - input_len = SHA256_DIGEST_LENGTH; - } -} - -TLSTicketKeyManager::TLSTicketSeed * -TLSTicketKeyManager::insertSeed(const string& seedInput, - TLSTicketSeedType type) { - TLSTicketSeed* seed = nullptr; - string seedOutput; - - if (!folly::unhexlify(seedInput, seedOutput)) { - LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type << - " seed=" << seedInput; - return seed; - } - - seed = new TLSTicketSeed(); - seed->seed_ = seedOutput; - seed->type_ = type; - SHA256((unsigned char *)seedOutput.data(), seedOutput.length(), - seed->seedName_); - ticketSeeds_.push_back(std::unique_ptr(seed)); - - return seed; -} - -TLSTicketKeyManager::TLSTicketKeySource * -TLSTicketKeyManager::findEncryptionKey() { - TLSTicketKeySource* result = nullptr; - // call to rand here is a bit hokey since it's not cryptographically - // random, and is predictably seeded with 0. However, activeKeys_ - // is probably not going to have very many keys in it, and most - // likely only 1. - size_t numKeys = activeKeys_.size(); - if (numKeys > 0) { - result = activeKeys_[rand_r(&randState_) % numKeys]; - } - return result; -} - -TLSTicketKeyManager::TLSTicketKeySource * -TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) { - string name((char *)keyName, kTLSTicketKeyNameLen); - TLSTicketKeySource* key = nullptr; - TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name); - if (mapit != ticketKeys_.end()) { - key = mapit->second.get(); - } - return key; -} - -void -TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey, - size_t keyLen, - unsigned char* salt, - unsigned char* output) { - SHA256_CTX hash_ctx; - - SHA256_Init(&hash_ctx); - SHA256_Update(&hash_ctx, parentKey, keyLen); - SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen); - SHA256_Final(output, &hash_ctx); -} - -} // namespace -#endif diff --git a/folly/experimental/wangle/ssl/TLSTicketKeyManager.h b/folly/experimental/wangle/ssl/TLSTicketKeyManager.h deleted file mode 100644 index 4000c139..00000000 --- a/folly/experimental/wangle/ssl/TLSTicketKeyManager.h +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -#include -#include - -namespace folly { - -#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB -class TLSTicketKeyManager {}; -#else -class SSLStats; -/** - * The TLSTicketKeyManager handles TLS ticket key encryption and decryption in - * a way that facilitates sharing the ticket keys across a range of servers. - * Hash chaining is employed to achieve frequent key rotation with minimal - * configuration change. The scheme is as follows: - * - * The manager is supplied with three lists of seeds (old, current and new). - * The config should be updated with new seeds periodically (e.g., daily). - * 3 config changes are recommended to achieve the smoothest seed rotation - * eg: - * 1. Introduce new seed in the push prior to rotation - * 2. Rotation push - * 3. Remove old seeds in the push following rotation - * - * Multiple seeds are supported but only a single seed is required. - * - * Generating encryption keys from the seed works as follows. For a given - * seed, hash forward N times where N is currently the constant 1. - * This is the base key. The name of the base key is the first 4 - * bytes of hash(hash(seed), N). This is copied into the first 4 bytes of the - * TLS ticket key name field. - * - * For each new ticket encryption, the manager generates a random 12 byte salt. - * Hash the salt and the base key together to form the encryption key for - * that ticket. The salt is included in the ticket's 'key name' field so it - * can be used to derive the decryption key. The salt is copied into the second - * 8 bytes of the TLS ticket key name field. - * - * A key is valid for decryption for the lifetime of the instance. - * Sessions will be valid for less time than that, which results in an extra - * symmetric decryption to discover the session is expired. - * - * A TLSTicketKeyManager should be used in only one thread, and should have - * a 1:1 relationship with the SSLContext provided. - * - */ -class TLSTicketKeyManager : private boost::noncopyable { - public: - - explicit TLSTicketKeyManager(folly::SSLContext* ctx, - SSLStats* stats); - - virtual ~TLSTicketKeyManager(); - - /** - * SSL callback to set up encryption/decryption context for a TLS Ticket Key. - * - * This will be supplied to the SSL library via - * SSL_CTX_set_tlsext_ticket_key_cb. - */ - static int callback(SSL* ssl, unsigned char* keyName, - unsigned char* iv, - EVP_CIPHER_CTX* cipherCtx, - HMAC_CTX* hmacCtx, int encrypt); - - /** - * Initialize the manager with three sets of seeds. There must be at least - * one current seed, or the manager will revert to the default SSL behavior. - * - * @param oldSeeds Seeds previously used which can still decrypt. - * @param currentSeeds Seeds to use for new ticket encryptions. - * @param newSeeds Seeds which will be used soon, can be used to decrypt - * in case some servers in the cluster have already rotated. - */ - bool setTLSTicketKeySeeds(const std::vector& oldSeeds, - const std::vector& currentSeeds, - const std::vector& newSeeds); - - private: - enum TLSTicketSeedType { - SEED_OLD = 0, - SEED_CURRENT, - SEED_NEW - }; - - /* The seeds supplied by the configuration */ - struct TLSTicketSeed { - std::string seed_; - TLSTicketSeedType type_; - unsigned char seedName_[SHA256_DIGEST_LENGTH]; - }; - - struct TLSTicketKeySource { - int32_t hashCount_; - std::string keyName_; - TLSTicketSeedType type_; - unsigned char keySource_[SHA256_DIGEST_LENGTH]; - }; - - /** - * Method to setup encryption/decryption context for a TLS Ticket Key - * - * OpenSSL documentation is thin on the return value semantics. - * - * For encrypt=1, return < 0 on error, >= 0 for successfully initialized - * For encrypt=0, return < 0 on error, 0 on key not found - * 1 on key found, 2 renew_ticket - * - * renew_ticket means a new ticket will be issued. We could return this value - * when receiving a ticket encrypted with a key derived from an OLD seed. - * However, session_timeout seconds after deploying with a seed - * rotated from CURRENT -> OLD, there will be no valid tickets outstanding - * encrypted with the old key. This grace period means no unnecessary - * handshakes will be performed. If the seed is believed compromised, it - * should NOT be configured as an OLD seed. - */ - int processTicket(SSL* ssl, unsigned char* keyName, - unsigned char* iv, - EVP_CIPHER_CTX* cipherCtx, - HMAC_CTX* hmacCtx, int encrypt); - - // Creates the name for the nth key generated from seed - std::string makeKeyName(TLSTicketSeed* seed, uint32_t n, - unsigned char* nameBuf); - - /** - * Creates the key hashCount hashes from the given seed and inserts it in - * ticketKeys. A naked pointer to the key is returned for additional - * processing if needed. - */ - TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount, - TLSTicketKeySource* prevKeySource); - - /** - * hashes input N times placing result in output, which must be at least - * SHA256_DIGEST_LENGTH long. - */ - void hashNth(const unsigned char* input, size_t input_len, - unsigned char* output, uint32_t n); - - /** - * Adds the given seed to the manager - */ - TLSTicketSeed* insertSeed(const std::string& seedInput, - TLSTicketSeedType type); - - /** - * Locate a key for encrypting a new ticket - */ - TLSTicketKeySource* findEncryptionKey(); - - /** - * Locate a key for decrypting a ticket with the given keyName - */ - TLSTicketKeySource* findDecryptionKey(unsigned char* keyName); - - /** - * Derive a unique key from the parent key and the salt via hashing - */ - void makeUniqueKeys(unsigned char* parentKey, size_t keyLen, - unsigned char* salt, unsigned char* output); - - /** - * For standalone decryption utility - */ - friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager, - const std::string& testTicket, - SSL_SESSION **psess); - - typedef std::vector> TLSTicketSeedList; - typedef std::map > - TLSTicketKeyMap; - typedef std::vector TLSActiveKeyList; - - TLSTicketSeedList ticketSeeds_; - // All key sources that can be used for decryption - TLSTicketKeyMap ticketKeys_; - // Key sources that can be used for encryption - TLSActiveKeyList activeKeys_; - - folly::SSLContext* ctx_; - uint32_t randState_; - SSLStats* stats_{nullptr}; - - static int32_t sExDataIndex_; -}; -#endif -} diff --git a/folly/experimental/wangle/ssl/TLSTicketKeySeeds.h b/folly/experimental/wangle/ssl/TLSTicketKeySeeds.h deleted file mode 100644 index c40ae581..00000000 --- a/folly/experimental/wangle/ssl/TLSTicketKeySeeds.h +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#pragma once - -namespace folly { - -struct TLSTicketKeySeeds { - std::vector oldSeeds; - std::vector currentSeeds; - std::vector newSeeds; -}; - -} diff --git a/folly/experimental/wangle/ssl/test/SSLCacheTest.cpp b/folly/experimental/wangle/ssl/test/SSLCacheTest.cpp deleted file mode 100644 index 2433cfc0..00000000 --- a/folly/experimental/wangle/ssl/test/SSLCacheTest.cpp +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std; -using namespace folly; - -DEFINE_int32(clients, 1, "Number of simulated SSL clients"); -DEFINE_int32(threads, 1, "Number of threads to spread clients across"); -DEFINE_int32(requests, 2, "Total number of requests per client"); -DEFINE_int32(port, 9423, "Server port"); -DEFINE_bool(sticky, false, "A given client sends all reqs to one " - "(random) server"); -DEFINE_bool(global, false, "All clients in a thread use the same SSL session"); -DEFINE_bool(handshakes, false, "Force 100% handshakes"); - -string f_servers[10]; -int f_num_servers = 0; -int tnum = 0; - -class ClientRunner { - public: - - ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {} - void run(); - - int reqs; - int hits; - int miss; - int num; -}; - -class SSLCacheClient : public AsyncSocket::ConnectCallback, - public AsyncSSLSocket::HandshakeCB -{ -private: - EventBase* eventBase_; - int currReq_; - int serverIdx_; - AsyncSocket* socket_; - AsyncSSLSocket* sslSocket_; - SSL_SESSION* session_; - SSL_SESSION **pSess_; - std::shared_ptr ctx_; - ClientRunner* cr_; - -public: - SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr); - ~SSLCacheClient() { - if (session_ && !FLAGS_global) - SSL_SESSION_free(session_); - if (socket_ != nullptr) { - if (sslSocket_ != nullptr) { - sslSocket_->destroy(); - sslSocket_ = nullptr; - } - socket_->destroy(); - socket_ = nullptr; - } - }; - - void start(); - - virtual void connectSuccess() noexcept; - - virtual void connectErr(const AsyncSocketException& ex) - noexcept ; - - virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept; - - virtual void handshakeErr( - AsyncSSLSocket* sock, - const AsyncSocketException& ex) noexcept; - -}; - -int -main(int argc, char* argv[]) -{ - gflags::SetUsageMessage(std::string("\n\n" -"usage: sslcachetest [options] -c -t servers\n" -)); - gflags::ParseCommandLineFlags(&argc, &argv, true); - int reqs = 0; - int hits = 0; - int miss = 0; - struct timeval start; - struct timeval end; - struct timeval result; - - srand((unsigned int)time(nullptr)); - - for (int i = 1; i < argc; i++) { - f_servers[f_num_servers++] = argv[i]; - } - if (f_num_servers == 0) { - cout << "require at least one server\n"; - return 1; - } - - gettimeofday(&start, nullptr); - if (FLAGS_threads == 1) { - ClientRunner r; - r.run(); - gettimeofday(&end, nullptr); - reqs = r.reqs; - hits = r.hits; - miss = r.miss; - } - else { - std::vector clients; - std::vector threads; - for (int t = 0; t < FLAGS_threads; t++) { - threads.emplace_back([&] { - clients[t].run(); - }); - } - for (auto& thr: threads) { - thr.join(); - } - gettimeofday(&end, nullptr); - - for (const auto& client: clients) { - reqs += client.reqs; - hits += client.hits; - miss += client.miss; - } - } - - timersub(&end, &start, &result); - - cout << "Requests: " << reqs << endl; - cout << "Handshakes: " << miss << endl; - cout << "Resumes: " << hits << endl; - cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 << - endl; - - cout << "ops/sec: " << (reqs * 1.0) / - ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl; - - return 0; -} - -void -ClientRunner::run() -{ - EventBase eb; - std::list clients; - SSL_SESSION* session = nullptr; - - for (int i = 0; i < FLAGS_clients; i++) { - SSLCacheClient* c = new SSLCacheClient(&eb, &session, this); - c->start(); - clients.push_back(c); - } - - eb.loop(); - - for (auto it = clients.begin(); it != clients.end(); it++) { - delete* it; - } - - reqs += hits + miss; -} - -SSLCacheClient::SSLCacheClient(EventBase* eb, - SSL_SESSION **pSess, - ClientRunner* cr) - : eventBase_(eb), - currReq_(0), - serverIdx_(0), - socket_(nullptr), - sslSocket_(nullptr), - session_(nullptr), - pSess_(pSess), - cr_(cr) -{ - ctx_.reset(new SSLContext()); - ctx_->setOptions(SSL_OP_NO_TICKET); -} - -void -SSLCacheClient::start() -{ - if (currReq_ >= FLAGS_requests) { - cout << "+"; - return; - } - - if (currReq_ == 0 || !FLAGS_sticky) { - serverIdx_ = rand() % f_num_servers; - } - if (socket_ != nullptr) { - if (sslSocket_ != nullptr) { - sslSocket_->destroy(); - sslSocket_ = nullptr; - } - socket_->destroy(); - socket_ = nullptr; - } - socket_ = new AsyncSocket(eventBase_); - socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port); -} - -void -SSLCacheClient::connectSuccess() noexcept -{ - sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(), - false); - - if (!FLAGS_handshakes) { - if (session_ != nullptr) - sslSocket_->setSSLSession(session_); - else if (FLAGS_global && pSess_ != nullptr) - sslSocket_->setSSLSession(*pSess_); - } - sslSocket_->sslConn(this); -} - -void -SSLCacheClient::connectErr(const AsyncSocketException& ex) - noexcept -{ - cout << "connectError: " << ex.what() << endl; -} - -void -SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept -{ - if (sslSocket_->getSSLSessionReused()) { - cr_->hits++; - } else { - cr_->miss++; - if (session_ != nullptr) { - SSL_SESSION_free(session_); - } - session_ = sslSocket_->getSSLSession(); - if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) { - *pSess_ = session_; - } - } - if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) { - cout << "."; - cout.flush(); - } - sslSocket_->closeNow(); - currReq_++; - this->start(); -} - -void -SSLCacheClient::handshakeErr( - AsyncSSLSocket* sock, - const AsyncSocketException& ex) - noexcept -{ - cout << "handshakeError: " << ex.what() << endl; -} diff --git a/folly/experimental/wangle/ssl/test/SSLContextManagerTest.cpp b/folly/experimental/wangle/ssl/test/SSLContextManagerTest.cpp deleted file mode 100644 index 6e5815c0..00000000 --- a/folly/experimental/wangle/ssl/test/SSLContextManagerTest.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2014, Facebook, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. An additional grant - * of patent rights can be found in the PATENTS file in the same directory. - * - */ -#include -#include -#include -#include -#include -#include - -using std::shared_ptr; - -namespace folly { - -TEST(SSLContextManagerTest, Test1) -{ - EventBase eventBase; - SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_", - true, nullptr); - auto www_facebook_com_ctx = std::make_shared(); - auto start_facebook_com_ctx = std::make_shared(); - auto start_abc_facebook_com_ctx = std::make_shared(); - - sslCtxMgr.insertSSLCtxByDomainName( - "www.facebook.com", - strlen("www.facebook.com"), - www_facebook_com_ctx); - sslCtxMgr.insertSSLCtxByDomainName( - "www.facebook.com", - strlen("www.facebook.com"), - www_facebook_com_ctx); - try { - sslCtxMgr.insertSSLCtxByDomainName( - "www.facebook.com", - strlen("www.facebook.com"), - std::make_shared()); - } catch (const std::exception& ex) { - } - sslCtxMgr.insertSSLCtxByDomainName( - "*.facebook.com", - strlen("*.facebook.com"), - start_facebook_com_ctx); - sslCtxMgr.insertSSLCtxByDomainName( - "*.abc.facebook.com", - strlen("*.abc.facebook.com"), - start_abc_facebook_com_ctx); - try { - sslCtxMgr.insertSSLCtxByDomainName( - "*.abc.facebook.com", - strlen("*.abc.facebook.com"), - std::make_shared()); - FAIL(); - } catch (const std::exception& ex) { - } - - shared_ptr retCtx; - retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com")); - EXPECT_EQ(retCtx, www_facebook_com_ctx); - retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com")); - EXPECT_EQ(retCtx, www_facebook_com_ctx); - EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com"))); - - retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com")); - EXPECT_EQ(retCtx, start_facebook_com_ctx); - retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com")); - EXPECT_EQ(retCtx, start_facebook_com_ctx); - - retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com")); - EXPECT_EQ(retCtx, start_abc_facebook_com_ctx); - - // ensure "facebook.com" does not match "*.facebook.com" - EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com"))); - // ensure "Xfacebook.com" does not match "*.facebook.com" - EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com"))); - // ensure wildcard name only matches one domain up - EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com"))); - - eventBase.loop(); // Clean up events before SSLContextManager is destructed -} - -} diff --git a/folly/wangle/acceptor/Acceptor.cpp b/folly/wangle/acceptor/Acceptor.cpp new file mode 100644 index 00000000..d02ad248 --- /dev/null +++ b/folly/wangle/acceptor/Acceptor.cpp @@ -0,0 +1,437 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using folly::wangle::ConnectionManager; +using folly::wangle::ManagedConnection; +using std::chrono::microseconds; +using std::chrono::milliseconds; +using std::filebuf; +using std::ifstream; +using std::ios; +using std::shared_ptr; +using std::string; + +namespace folly { + +#ifndef NO_LIB_GFLAGS +DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before " + "closing idle conns"); +#else +const int32_t FLAGS_shutdown_idle_grace_ms = 5000; +#endif + +static const std::string empty_string; +std::atomic Acceptor::totalNumPendingSSLConns_{0}; + +/** + * Lightweight wrapper class to keep track of a newly + * accepted connection during SSL handshaking. + */ +class AcceptorHandshakeHelper : + public AsyncSSLSocket::HandshakeCB, + public ManagedConnection { + public: + AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket, + Acceptor* acceptor, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime) + : socket_(std::move(socket)), acceptor_(acceptor), + acceptTime_(acceptTime), clientAddr_(clientAddr) { + acceptor_->downstreamConnectionManager_->addConnection(this, true); + if(acceptor_->parseClientHello_) { + socket_->enableClientHelloParsing(); + } + socket_->sslAccept(this); + } + + virtual void timeoutExpired() noexcept { + VLOG(4) << "SSL handshake timeout expired"; + sslError_ = SSLErrorEnum::TIMEOUT; + dropConnection(); + } + virtual void describe(std::ostream& os) const { + os << "pending handshake on " << clientAddr_; + } + virtual bool isBusy() const { + return true; + } + virtual void notifyPendingShutdown() {} + virtual void closeWhenIdle() {} + + virtual void dropConnection() { + VLOG(10) << "Dropping in progress handshake for " << clientAddr_; + socket_->closeNow(); + } + virtual void dumpConnectionState(uint8_t loglevel) { + } + + private: + // AsyncSSLSocket::HandshakeCallback API + virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept { + + const unsigned char* nextProto = nullptr; + unsigned nextProtoLength = 0; + sock->getSelectedNextProtocol(&nextProto, &nextProtoLength); + if (VLOG_IS_ON(3)) { + if (nextProto) { + VLOG(3) << "Client selected next protocol " << + string((const char*)nextProto, nextProtoLength); + } else { + VLOG(3) << "Client did not select a next protocol"; + } + } + + // fill in SSL-related fields from TransportInfo + // the other fields like RTT are filled in the Acceptor + TransportInfo tinfo; + tinfo.ssl = true; + tinfo.acceptTime = acceptTime_; + tinfo.sslSetupTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - acceptTime_); + tinfo.sslSetupBytesRead = sock->getRawBytesReceived(); + tinfo.sslSetupBytesWritten = sock->getRawBytesWritten(); + tinfo.sslServerName = sock->getSSLServerName(); + tinfo.sslCipher = sock->getNegotiatedCipherName(); + tinfo.sslVersion = sock->getSSLVersion(); + tinfo.sslCertSize = sock->getSSLCertSize(); + tinfo.sslResume = SSLUtil::getResumeState(sock); + sock->getSSLClientCiphers(tinfo.sslClientCiphers); + sock->getSSLServerCiphers(tinfo.sslServerCiphers); + tinfo.sslClientComprMethods = sock->getSSLClientComprMethods(); + tinfo.sslClientExts = sock->getSSLClientExts(); + tinfo.sslNextProtocol.assign( + reinterpret_cast(nextProto), + nextProtoLength); + + acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR); + acceptor_->downstreamConnectionManager_->removeConnection(this); + acceptor_->sslConnectionReady(std::move(socket_), clientAddr_, + nextProto ? string((const char*)nextProto, nextProtoLength) : + empty_string, tinfo); + delete this; + } + + virtual void handshakeErr(AsyncSSLSocket* sock, + const AsyncSocketException& ex) noexcept { + auto elapsedTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - acceptTime_); + VLOG(3) << "SSL handshake error after " << elapsedTime.count() << + " ms; " << sock->getRawBytesReceived() << " bytes received & " << + sock->getRawBytesWritten() << " bytes sent: " << + ex.what(); + acceptor_->updateSSLStats(sock, elapsedTime, sslError_); + acceptor_->sslConnectionError(); + delete this; + } + + AsyncSSLSocket::UniquePtr socket_; + Acceptor* acceptor_; + std::chrono::steady_clock::time_point acceptTime_; + SocketAddress clientAddr_; + SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR}; +}; + +Acceptor::Acceptor(const ServerSocketConfig& accConfig) : + accConfig_(accConfig), + socketOptions_(accConfig.getSocketOptions()) { +} + +void +Acceptor::init(AsyncServerSocket* serverSocket, + EventBase* eventBase) { + CHECK(nullptr == this->base_); + + if (accConfig_.isSSL()) { + if (!sslCtxManager_) { + sslCtxManager_ = folly::make_unique( + eventBase, + "vip_" + getName(), + accConfig_.strictSSL, nullptr); + } + for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) { + sslCtxManager_->addSSLContextConfig( + sslCtxConfig, + accConfig_.sslCacheOptions, + &accConfig_.initialTicketSeeds, + accConfig_.bindAddress, + cacheProvider_); + parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled; + } + + CHECK(sslCtxManager_->getDefaultSSLCtx()); + } + + base_ = eventBase; + state_ = State::kRunning; + downstreamConnectionManager_ = ConnectionManager::makeUnique( + eventBase, accConfig_.connectionIdleTimeout, this); + + if (serverSocket) { + serverSocket->addAcceptCallback(this, eventBase); + // SO_KEEPALIVE is the only setting that is inherited by accepted + // connections so only apply this setting + for (const auto& option: socketOptions_) { + if (option.first.level == SOL_SOCKET && + option.first.optname == SO_KEEPALIVE && option.second == 1) { + serverSocket->setKeepAliveEnabled(true); + break; + } + } + } +} + +Acceptor::~Acceptor(void) { +} + +void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) { + sslCtxManager_->addSSLContextConfig(sslCtxConfig, + accConfig_.sslCacheOptions, + &accConfig_.initialTicketSeeds, + accConfig_.bindAddress, + cacheProvider_); +} + +void +Acceptor::drainAllConnections() { + if (downstreamConnectionManager_) { + downstreamConnectionManager_->initiateGracefulShutdown( + std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms)); + } +} + +void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from, + IConnectionCounter* counter) { + loadShedConfig_ = from; + connectionCounter_ = counter; +} + +bool Acceptor::canAccept(const SocketAddress& address) { + if (!connectionCounter_) { + return true; + } + + uint64_t maxConnections = connectionCounter_->getMaxConnections(); + if (maxConnections == 0) { + return true; + } + + uint64_t currentConnections = connectionCounter_->getNumConnections(); + if (currentConnections < maxConnections) { + return true; + } + + if (loadShedConfig_.isWhitelisted(address)) { + return true; + } + + // Take care of comparing connection count against max connections across + // all acceptors. Expensive since a lock must be taken to get the counter. + auto connectionCountForLoadShedding = getConnectionCountForLoadShedding(); + if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) { + return true; + } + + VLOG(4) << address.describe() << " not whitelisted"; + return false; +} + +void +Acceptor::connectionAccepted( + int fd, const SocketAddress& clientAddr) noexcept { + if (!canAccept(clientAddr)) { + close(fd); + return; + } + auto acceptTime = std::chrono::steady_clock::now(); + for (const auto& opt: socketOptions_) { + opt.first.apply(fd, opt.second); + } + + onDoneAcceptingConnection(fd, clientAddr, acceptTime); +} + +void Acceptor::onDoneAcceptingConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime) noexcept { + processEstablishedConnection(fd, clientAddr, acceptTime); +} + +void +Acceptor::processEstablishedConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime) noexcept { + if (accConfig_.isSSL()) { + CHECK(sslCtxManager_); + AsyncSSLSocket::UniquePtr sslSock( + makeNewAsyncSSLSocket( + sslCtxManager_->getDefaultSSLCtx(), base_, fd)); + ++numPendingSSLConns_; + ++totalNumPendingSSLConns_; + if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) { + VLOG(2) << "dropped SSL handshake on " << accConfig_.name << + " too many handshakes in progress"; + updateSSLStats(sslSock.get(), std::chrono::milliseconds(0), + SSLErrorEnum::DROPPED); + sslConnectionError(); + return; + } + new AcceptorHandshakeHelper( + std::move(sslSock), this, clientAddr, acceptTime); + } else { + TransportInfo tinfo; + tinfo.ssl = false; + tinfo.acceptTime = acceptTime; + AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd)); + connectionReady(std::move(sock), clientAddr, empty_string, tinfo); + } +} + +void +Acceptor::connectionReady( + AsyncSocket::UniquePtr sock, + const SocketAddress& clientAddr, + const string& nextProtocolName, + TransportInfo& tinfo) { + // Limit the number of reads from the socket per poll loop iteration, + // both to keep memory usage under control and to prevent one fast- + // writing client from starving other connections. + sock->setMaxReadsPerEvent(16); + tinfo.initWithSocket(sock.get()); + onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo); +} + +void +Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock, + const SocketAddress& clientAddr, + const string& nextProtocol, + TransportInfo& tinfo) { + CHECK(numPendingSSLConns_ > 0); + connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo); + --numPendingSSLConns_; + --totalNumPendingSSLConns_; + if (state_ == State::kDraining) { + checkDrained(); + } +} + +void +Acceptor::sslConnectionError() { + CHECK(numPendingSSLConns_ > 0); + --numPendingSSLConns_; + --totalNumPendingSSLConns_; + if (state_ == State::kDraining) { + checkDrained(); + } +} + +void +Acceptor::acceptError(const std::exception& ex) noexcept { + // An error occurred. + // The most likely error is out of FDs. AsyncServerSocket will back off + // briefly if we are out of FDs, then continue accepting later. + // Just log a message here. + LOG(ERROR) << "error accepting on acceptor socket: " << ex.what(); +} + +void +Acceptor::acceptStopped() noexcept { + VLOG(3) << "Acceptor " << this << " acceptStopped()"; + // Drain the open client connections + drainAllConnections(); + + // If we haven't yet finished draining, begin doing so by marking ourselves + // as in the draining state. We must be sure to hit checkDrained() here, as + // if we're completely idle, we can should consider ourself drained + // immediately (as there is no outstanding work to complete to cause us to + // re-evaluate this). + if (state_ != State::kDone) { + state_ = State::kDraining; + checkDrained(); + } +} + +void +Acceptor::onEmpty(const ConnectionManager& cm) { + VLOG(3) << "Acceptor=" << this << " onEmpty()"; + if (state_ == State::kDraining) { + checkDrained(); + } +} + +void +Acceptor::checkDrained() { + CHECK(state_ == State::kDraining); + if (forceShutdownInProgress_ || + (downstreamConnectionManager_->getNumConnections() != 0) || + (numPendingSSLConns_ != 0)) { + return; + } + + VLOG(2) << "All connections drained from Acceptor=" << this << " in thread " + << base_; + + downstreamConnectionManager_.reset(); + + state_ = State::kDone; + + onConnectionsDrained(); +} + +milliseconds +Acceptor::getConnTimeout() const { + return accConfig_.connectionIdleTimeout; +} + +void Acceptor::addConnection(ManagedConnection* conn) { + // Add the socket to the timeout manager so that it can be cleaned + // up after being left idle for a long time. + downstreamConnectionManager_->addConnection(conn, true); +} + +void +Acceptor::forceStop() { + base_->runInEventBaseThread([&] { dropAllConnections(); }); +} + +void +Acceptor::dropAllConnections() { + if (downstreamConnectionManager_) { + VLOG(3) << "Dropping all connections from Acceptor=" << this << + " in thread " << base_; + assert(base_->isInEventBaseThread()); + forceShutdownInProgress_ = true; + downstreamConnectionManager_->dropAllConnections(); + CHECK(downstreamConnectionManager_->getNumConnections() == 0); + downstreamConnectionManager_.reset(); + } + CHECK(numPendingSSLConns_ == 0); + + state_ = State::kDone; + onConnectionsDrained(); +} + +} // namespace diff --git a/folly/wangle/acceptor/Acceptor.h b/folly/wangle/acceptor/Acceptor.h new file mode 100644 index 00000000..9c4563ba --- /dev/null +++ b/folly/wangle/acceptor/Acceptor.h @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include "folly/wangle/acceptor/ServerSocketConfig.h" +#include "folly/wangle/acceptor/ConnectionCounter.h" +#include +#include "folly/wangle/acceptor/LoadShedConfiguration.h" +#include "folly/wangle/ssl/SSLCacheProvider.h" +#include "folly/wangle/acceptor/TransportInfo.h" + +#include +#include +#include +#include + +namespace folly { namespace wangle { +class ManagedConnection; +}} + +namespace folly { + +class SocketAddress; +class SSLContext; +class AsyncTransport; +class SSLContextManager; + +/** + * An abstract acceptor for TCP-based network services. + * + * There is one acceptor object per thread for each listening socket. When a + * new connection arrives on the listening socket, it is accepted by one of the + * acceptor objects. From that point on the connection will be processed by + * that acceptor's thread. + * + * The acceptor will call the abstract onNewConnection() method to create + * a new ManagedConnection object for each accepted socket. The acceptor + * also tracks all outstanding connections that it has accepted. + */ +class Acceptor : + public folly::AsyncServerSocket::AcceptCallback, + public folly::wangle::ConnectionManager::Callback { + public: + + enum class State : uint32_t { + kInit, // not yet started + kRunning, // processing requests normally + kDraining, // processing outstanding conns, but not accepting new ones + kDone, // no longer accepting, and all connections finished + }; + + explicit Acceptor(const ServerSocketConfig& accConfig); + virtual ~Acceptor(); + + /** + * Supply an SSL cache provider + * @note Call this before init() + */ + virtual void setSSLCacheProvider( + const std::shared_ptr& cacheProvider) { + cacheProvider_ = cacheProvider; + } + + /** + * Initialize the Acceptor to run in the specified EventBase + * thread, receiving connections from the specified AsyncServerSocket. + * + * This method will be called from the AsyncServerSocket's primary thread, + * not the specified EventBase thread. + */ + virtual void init(AsyncServerSocket* serverSocket, + EventBase* eventBase); + + /** + * Dynamically add a new SSLContextConfig + */ + void addSSLContextConfig(const SSLContextConfig& sslCtxConfig); + + SSLContextManager* getSSLContextManager() const { + return sslCtxManager_.get(); + } + + /** + * Return the number of outstanding connections in this service instance. + */ + uint32_t getNumConnections() const { + return downstreamConnectionManager_ ? + downstreamConnectionManager_->getNumConnections() : 0; + } + + /** + * Access the Acceptor's event base. + */ + virtual EventBase* getEventBase() const { return base_; } + + /** + * Access the Acceptor's downstream (client-side) ConnectionManager + */ + virtual folly::wangle::ConnectionManager* getConnectionManager() { + return downstreamConnectionManager_.get(); + } + + /** + * Invoked when a new ManagedConnection is created. + * + * This allows the Acceptor to track the outstanding connections, + * for tracking timeouts and for ensuring that all connections have been + * drained on shutdown. + */ + void addConnection(folly::wangle::ManagedConnection* connection); + + /** + * Get this acceptor's current state. + */ + State getState() const { + return state_; + } + + /** + * Get the current connection timeout. + */ + std::chrono::milliseconds getConnTimeout() const; + + /** + * Returns the name of this VIP. + * + * Will return an empty string if no name has been configured. + */ + const std::string& getName() const { + return accConfig_.name; + } + + /** + * Force the acceptor to drop all connections and stop processing. + * + * This function may be called from any thread. The acceptor will not + * necessarily stop before this function returns: the stop will be scheduled + * to run in the acceptor's thread. + */ + virtual void forceStop(); + + bool isSSL() const { return accConfig_.isSSL(); } + + const ServerSocketConfig& getConfig() const { return accConfig_; } + + static uint64_t getTotalNumPendingSSLConns() { + return totalNumPendingSSLConns_.load(); + } + + /** + * Called right when the TCP connection has been accepted, before processing + * the first HTTP bytes (HTTP) or the SSL handshake (HTTPS) + */ + virtual void onDoneAcceptingConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime + ) noexcept; + + /** + * Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS) + */ + void processEstablishedConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime + ) noexcept; + + /** + * Drains all open connections of their outstanding transactions. When + * a connection's transaction count reaches zero, the connection closes. + */ + void drainAllConnections(); + + protected: + friend class AcceptorHandshakeHelper; + + /** + * Our event loop. + * + * Probably needs to be used to pass to a ManagedConnection + * implementation. Also visible in case a subclass wishes to do additional + * things w/ the event loop (e.g. in attach()). + */ + EventBase* base_{nullptr}; + + virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; } + + /** + * Hook for subclasses to drop newly accepted connections prior + * to handshaking. + */ + virtual bool canAccept(const folly::SocketAddress&); + + /** + * Invoked when a new connection is created. This is where application starts + * processing a new downstream connection. + * + * NOTE: Application should add the new connection to + * downstreamConnectionManager so that it can be garbage collected after + * certain period of idleness. + * + * @param sock the socket connected to the client + * @param address the address of the client + * @param nextProtocolName the name of the L6 or L7 protocol to be + * spoken on the connection, if known (e.g., + * from TLS NPN during secure connection setup), + * or an empty string if unknown + */ + virtual void onNewConnection( + AsyncSocket::UniquePtr sock, + const folly::SocketAddress* address, + const std::string& nextProtocolName, + const TransportInfo& tinfo) = 0; + + virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) { + return AsyncSocket::UniquePtr(new AsyncSocket(base, fd)); + } + + virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket( + const std::shared_ptr& ctx, EventBase* base, int fd) { + return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd)); + } + + /** + * Hook for subclasses to record stats about SSL connection establishment. + */ + virtual void updateSSLStats( + const AsyncSSLSocket* sock, + std::chrono::milliseconds acceptLatency, + SSLErrorEnum error) noexcept {} + + /** + * Drop all connections. + * + * forceStop() schedules dropAllConnections() to be called in the acceptor's + * thread. + */ + void dropAllConnections(); + + protected: + + /** + * onConnectionsDrained() will be called once all connections have been + * drained while the acceptor is stopping. + * + * Subclasses can override this method to perform any subclass-specific + * cleanup. + */ + virtual void onConnectionsDrained() {} + + // AsyncServerSocket::AcceptCallback methods + void connectionAccepted(int fd, + const folly::SocketAddress& clientAddr) + noexcept; + void acceptError(const std::exception& ex) noexcept; + void acceptStopped() noexcept; + + // ConnectionManager::Callback methods + void onEmpty(const folly::wangle::ConnectionManager& cm); + void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {} + void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {} + + /** + * Process a connection that is to ready to receive L7 traffic. + * This method is called immediately upon accept for plaintext + * connections and upon completion of SSL handshaking or resumption + * for SSL connections. + */ + void connectionReady( + AsyncSocket::UniquePtr sock, + const folly::SocketAddress& clientAddr, + const std::string& nextProtocolName, + TransportInfo& tinfo); + + const LoadShedConfiguration& getLoadShedConfiguration() const { + return loadShedConfig_; + } + + protected: + const ServerSocketConfig accConfig_; + void setLoadShedConfig(const LoadShedConfiguration& from, + IConnectionCounter* counter); + + /** + * Socket options to apply to the client socket + */ + AsyncSocket::OptionMap socketOptions_; + + std::unique_ptr sslCtxManager_; + + /** + * Whether we want to enable client hello parsing in the handshake helper + * to get list of supported client ciphers. + */ + bool parseClientHello_{false}; + + folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_; + + private: + + // Forbidden copy constructor and assignment opererator + Acceptor(Acceptor const &) = delete; + Acceptor& operator=(Acceptor const &) = delete; + + /** + * Wrapper for connectionReady() that decrements the count of + * pending SSL connections. + */ + void sslConnectionReady(AsyncSocket::UniquePtr sock, + const folly::SocketAddress& clientAddr, + const std::string& nextProtocol, + TransportInfo& tinfo); + + /** + * Notification callback for SSL handshake failures. + */ + void sslConnectionError(); + + void checkDrained(); + + State state_{State::kInit}; + uint64_t numPendingSSLConns_{0}; + + static std::atomic totalNumPendingSSLConns_; + + bool forceShutdownInProgress_{false}; + LoadShedConfiguration loadShedConfig_; + IConnectionCounter* connectionCounter_{nullptr}; + std::shared_ptr cacheProvider_; +}; + +class AcceptorFactory { + public: + virtual std::shared_ptr newAcceptor() = 0; + virtual ~AcceptorFactory() = default; +}; + +} // namespace diff --git a/folly/wangle/acceptor/ConnectionCounter.h b/folly/wangle/acceptor/ConnectionCounter.h new file mode 100644 index 00000000..bf891bb2 --- /dev/null +++ b/folly/wangle/acceptor/ConnectionCounter.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +class IConnectionCounter { + public: + virtual uint64_t getNumConnections() const = 0; + + /** + * Get the maximum number of non-whitelisted client-side connections + * across all Acceptors managed by this. A value + * of zero means "unlimited." + */ + virtual uint64_t getMaxConnections() const = 0; + + /** + * Increment the count of client-side connections. + */ + virtual void onConnectionAdded() = 0; + + /** + * Decrement the count of client-side connections. + */ + virtual void onConnectionRemoved() = 0; + virtual ~IConnectionCounter() {} +}; + +class SimpleConnectionCounter: public IConnectionCounter { + public: + uint64_t getNumConnections() const override { return numConnections_; } + uint64_t getMaxConnections() const override { return maxConnections_; } + void setMaxConnections(uint64_t maxConnections) { + maxConnections_ = maxConnections; + } + + void onConnectionAdded() override { numConnections_++; } + void onConnectionRemoved() override { numConnections_--; } + virtual ~SimpleConnectionCounter() {} + + protected: + uint64_t maxConnections_{0}; + uint64_t numConnections_{0}; +}; + +} diff --git a/folly/wangle/acceptor/ConnectionManager.cpp b/folly/wangle/acceptor/ConnectionManager.cpp new file mode 100644 index 00000000..72b1492f --- /dev/null +++ b/folly/wangle/acceptor/ConnectionManager.cpp @@ -0,0 +1,175 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +using folly::HHWheelTimer; +using std::chrono::milliseconds; + +namespace folly { namespace wangle { + +ConnectionManager::ConnectionManager(EventBase* eventBase, + milliseconds timeout, Callback* callback) + : connTimeouts_(new HHWheelTimer(eventBase)), + callback_(callback), + eventBase_(eventBase), + idleIterator_(conns_.end()), + idleLoopCallback_(this), + timeout_(timeout) { + +} + +void +ConnectionManager::addConnection(ManagedConnection* connection, + bool timeout) { + CHECK_NOTNULL(connection); + ConnectionManager* oldMgr = connection->getConnectionManager(); + if (oldMgr != this) { + if (oldMgr) { + // 'connection' was being previously managed in a different thread. + // We must remove it from that manager before adding it to this one. + oldMgr->removeConnection(connection); + } + conns_.push_back(*connection); + connection->setConnectionManager(this); + if (callback_) { + callback_->onConnectionAdded(*this); + } + } + if (timeout) { + scheduleTimeout(connection); + } +} + +void +ConnectionManager::scheduleTimeout(ManagedConnection* connection) { + if (timeout_ > std::chrono::milliseconds(0)) { + connTimeouts_->scheduleTimeout(connection, timeout_); + } +} + +void ConnectionManager::scheduleTimeout( + folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout) { + connTimeouts_->scheduleTimeout(callback, timeout); +} + +void +ConnectionManager::removeConnection(ManagedConnection* connection) { + if (connection->getConnectionManager() == this) { + connection->cancelTimeout(); + connection->setConnectionManager(nullptr); + + // Un-link the connection from our list, being careful to keep the iterator + // that we're using for idle shedding valid + auto it = conns_.iterator_to(*connection); + if (it == idleIterator_) { + ++idleIterator_; + } + conns_.erase(it); + + if (callback_) { + callback_->onConnectionRemoved(*this); + if (getNumConnections() == 0) { + callback_->onEmpty(*this); + } + } + } +} + +void +ConnectionManager::initiateGracefulShutdown( + std::chrono::milliseconds idleGrace) { + if (idleGrace.count() > 0) { + idleLoopCallback_.scheduleTimeout(idleGrace); + VLOG(3) << "Scheduling idle grace period of " << idleGrace.count() << "ms"; + } else { + action_ = ShutdownAction::DRAIN2; + VLOG(3) << "proceeding directly to closing idle connections"; + } + drainAllConnections(); +} + +void +ConnectionManager::drainAllConnections() { + DestructorGuard g(this); + size_t numCleared = 0; + size_t numKept = 0; + + auto it = idleIterator_ == conns_.end() ? + conns_.begin() : idleIterator_; + + while (it != conns_.end() && (numKept + numCleared) < 64) { + ManagedConnection& conn = *it++; + if (action_ == ShutdownAction::DRAIN1) { + conn.notifyPendingShutdown(); + } else { + // Second time around: close idle sessions. If they aren't idle yet, + // have them close when they are idle + if (conn.isBusy()) { + numKept++; + } else { + numCleared++; + } + conn.closeWhenIdle(); + } + } + + if (action_ == ShutdownAction::DRAIN2) { + VLOG(2) << "Idle connections cleared: " << numCleared << + ", busy conns kept: " << numKept; + } + if (it != conns_.end()) { + idleIterator_ = it; + eventBase_->runInLoop(&idleLoopCallback_); + } else { + action_ = ShutdownAction::DRAIN2; + } +} + +void +ConnectionManager::dropAllConnections() { + DestructorGuard g(this); + + // Iterate through our connection list, and drop each connection. + VLOG(3) << "connections to drop: " << conns_.size(); + idleLoopCallback_.cancelTimeout(); + unsigned i = 0; + while (!conns_.empty()) { + ManagedConnection& conn = conns_.front(); + conns_.pop_front(); + conn.cancelTimeout(); + conn.setConnectionManager(nullptr); + // For debugging purposes, dump information about the first few + // connections. + static const unsigned MAX_CONNS_TO_DUMP = 2; + if (++i <= MAX_CONNS_TO_DUMP) { + conn.dumpConnectionState(3); + } + conn.dropConnection(); + } + idleIterator_ = conns_.end(); + idleLoopCallback_.cancelLoopCallback(); + + if (callback_) { + callback_->onEmpty(*this); + } +} + +}} // folly::wangle diff --git a/folly/wangle/acceptor/ConnectionManager.h b/folly/wangle/acceptor/ConnectionManager.h new file mode 100644 index 00000000..2b0daa8d --- /dev/null +++ b/folly/wangle/acceptor/ConnectionManager.h @@ -0,0 +1,200 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +/** + * A ConnectionManager keeps track of ManagedConnections. + */ +class ConnectionManager: public folly::DelayedDestruction { + public: + + /** + * Interface for an optional observer that's notified about + * various events in a ConnectionManager + */ + class Callback { + public: + virtual ~Callback() {} + + /** + * Invoked when the number of connections managed by the + * ConnectionManager changes from nonzero to zero. + */ + virtual void onEmpty(const ConnectionManager& cm) = 0; + + /** + * Invoked when a connection is added to the ConnectionManager. + */ + virtual void onConnectionAdded(const ConnectionManager& cm) = 0; + + /** + * Invoked when a connection is removed from the ConnectionManager. + */ + virtual void onConnectionRemoved(const ConnectionManager& cm) = 0; + }; + + typedef std::unique_ptr UniquePtr; + + /** + * Returns a new instance of ConnectionManager wrapped in a unique_ptr + */ + template + static UniquePtr makeUnique(Args&&... args) { + return folly::make_unique( + std::forward(args)...); + } + + /** + * Constructor not to be used by itself. + */ + ConnectionManager(folly::EventBase* eventBase, + std::chrono::milliseconds timeout, + Callback* callback = nullptr); + + /** + * Add a connection to the set of connections managed by this + * ConnectionManager. + * + * @param connection The connection to add. + * @param timeout Whether to immediately register this connection + * for an idle timeout callback. + */ + void addConnection(ManagedConnection* connection, + bool timeout = false); + + /** + * Schedule a timeout callback for a connection. + */ + void scheduleTimeout(ManagedConnection* connection); + + /* + * Schedule a callback on the wheel timer + */ + void scheduleTimeout(folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout); + + /** + * Remove a connection from this ConnectionManager and, if + * applicable, cancel the pending timeout callback that the + * ConnectionManager has scheduled for the connection. + * + * @note This method does NOT destroy the connection. + */ + void removeConnection(ManagedConnection* connection); + + /* Begin gracefully shutting down connections in this ConnectionManager. + * Notify all connections of pending shutdown, and after idleGrace, + * begin closing idle connections. + */ + void initiateGracefulShutdown(std::chrono::milliseconds idleGrace); + + /** + * Destroy all connections Managed by this ConnectionManager, even + * the ones that are busy. + */ + void dropAllConnections(); + + size_t getNumConnections() const { return conns_.size(); } + + template + void iterateConns(F func) { + auto it = conns_.begin(); + while ( it != conns_.end()) { + func(&(*it)); + it++; + } + } + + private: + class CloseIdleConnsCallback : + public folly::EventBase::LoopCallback, + public folly::AsyncTimeout { + public: + explicit CloseIdleConnsCallback(ConnectionManager* manager) + : folly::AsyncTimeout(manager->eventBase_), + manager_(manager) {} + + void runLoopCallback() noexcept override { + VLOG(3) << "Draining more conns from loop callback"; + manager_->drainAllConnections(); + } + + void timeoutExpired() noexcept override { + VLOG(3) << "Idle grace expired"; + manager_->drainAllConnections(); + } + + private: + ConnectionManager* manager_; + }; + + enum class ShutdownAction : uint8_t { + /** + * Drain part 1: inform remote that you will soon reject new requests. + */ + DRAIN1 = 0, + /** + * Drain part 2: start rejecting new requests. + */ + DRAIN2 = 1, + }; + + ~ConnectionManager() {} + + ConnectionManager(const ConnectionManager&) = delete; + ConnectionManager& operator=(ConnectionManager&) = delete; + + /** + * Destroy all connections managed by this ConnectionManager that + * are currently idle, as determined by a call to each ManagedConnection's + * isBusy() method. + */ + void drainAllConnections(); + + /** All connections */ + folly::CountedIntrusiveList< + ManagedConnection,&ManagedConnection::listHook_> conns_; + + /** Connections that currently are registered for timeouts */ + folly::HHWheelTimer::UniquePtr connTimeouts_; + + /** Optional callback to notify of state changes */ + Callback* callback_; + + /** Event base in which we run */ + folly::EventBase* eventBase_; + + /** Iterator to the next connection to shed; used by drainAllConnections() */ + folly::CountedIntrusiveList< + ManagedConnection,&ManagedConnection::listHook_>::iterator idleIterator_; + CloseIdleConnsCallback idleLoopCallback_; + ShutdownAction action_{ShutdownAction::DRAIN1}; + std::chrono::milliseconds timeout_; +}; + +}} // folly::wangle diff --git a/folly/wangle/acceptor/DomainNameMisc.h b/folly/wangle/acceptor/DomainNameMisc.h new file mode 100644 index 00000000..41c4c741 --- /dev/null +++ b/folly/wangle/acceptor/DomainNameMisc.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +struct dn_char_traits : public std::char_traits { + static bool eq(char c1, char c2) { + return ::tolower(c1) == ::tolower(c2); + } + + static bool ne(char c1, char c2) { + return ::tolower(c1) != ::tolower(c2); + } + + static bool lt(char c1, char c2) { + return ::tolower(c1) < ::tolower(c2); + } + + static int compare(const char* s1, const char* s2, size_t n) { + while (n--) { + if(::tolower(*s1) < ::tolower(*s2) ) { + return -1; + } + if(::tolower(*s1) > ::tolower(*s2) ) { + return 1; + } + ++s1; + ++s2; + } + return 0; + } + + static const char* find(const char* s, size_t n, char a) { + char la = ::tolower(a); + while (n--) { + if(::tolower(*s) == la) { + return s; + } else { + ++s; + } + } + return nullptr; + } +}; + +// Case insensitive string +typedef std::basic_string DNString; + +struct DNStringHash : public std::hash { + size_t operator()(const DNString& s) const noexcept { + size_t h = static_cast(0xc70f6907UL); + const char* d = s.data(); + for (size_t i = 0; i < s.length(); ++i) { + char a = ::tolower(*d++); + h = std::_Hash_impl::hash(&a, sizeof(a), h); + } + return h; + } +}; + +} // namespace diff --git a/folly/wangle/acceptor/LoadShedConfiguration.cpp b/folly/wangle/acceptor/LoadShedConfiguration.cpp new file mode 100644 index 00000000..191a7dce --- /dev/null +++ b/folly/wangle/acceptor/LoadShedConfiguration.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +using std::string; + +namespace folly { + +void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) { + auto addr = input.str(); + size_t separator = addr.find_first_of('/'); + if (separator == string::npos) { + whitelistAddrs_.insert(SocketAddress(addr, 0)); + } else { + unsigned prefixLen = folly::to(addr.substr(separator + 1)); + addr.erase(separator); + whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen)); + } +} + +bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const { + if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) { + return true; + } + for (auto& network : whitelistNetworks_) { + if (network.contains(address)) { + return true; + } + } + return false; +} + +} diff --git a/folly/wangle/acceptor/LoadShedConfiguration.h b/folly/wangle/acceptor/LoadShedConfiguration.h new file mode 100644 index 00000000..97e32027 --- /dev/null +++ b/folly/wangle/acceptor/LoadShedConfiguration.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace folly { + +/** + * Class that holds an LoadShed configuration for a service + */ +class LoadShedConfiguration { + public: + + // Comparison function for SocketAddress that disregards the port + struct AddressOnlyCompare { + bool operator()( + const SocketAddress& addr1, + const SocketAddress& addr2) const { + return addr1.getIPAddress() < addr2.getIPAddress(); + } + }; + + typedef std::set AddressSet; + typedef std::set NetworkSet; + + LoadShedConfiguration() {} + + ~LoadShedConfiguration() {} + + void addWhitelistAddr(folly::StringPiece); + + /** + * Set/get the set of IPs that should be whitelisted through even when we're + * trying to shed load. + */ + void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; } + const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; } + + /** + * Set/get the set of networks that should be whitelisted through even + * when we're trying to shed load. + */ + void setWhitelistNetworks(const NetworkSet& networks) { + whitelistNetworks_ = networks; + } + const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; } + + /** + * Set/get the maximum number of downstream connections across all VIPs. + */ + void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; } + uint64_t getMaxConnections() const { return maxConnections_; } + + /** + * Set/get the maximum cpu usage. + */ + void setMaxMemUsage(double max) { + CHECK(max >= 0); + CHECK(max <= 1); + maxMemUsage_ = max; + } + double getMaxMemUsage() const { return maxMemUsage_; } + + /** + * Set/get the maximum memory usage. + */ + void setMaxCpuUsage(double max) { + CHECK(max >= 0); + CHECK(max <= 1); + maxCpuUsage_ = max; + } + double getMaxCpuUsage() const { return maxCpuUsage_; } + + void setLoadUpdatePeriod(std::chrono::milliseconds period) { + period_ = period; + } + std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; } + + bool isWhitelisted(const SocketAddress& addr) const; + + private: + + AddressSet whitelistAddrs_; + NetworkSet whitelistNetworks_; + uint64_t maxConnections_{0}; + double maxMemUsage_; + double maxCpuUsage_; + std::chrono::milliseconds period_; +}; + +} diff --git a/folly/wangle/acceptor/ManagedConnection.cpp b/folly/wangle/acceptor/ManagedConnection.cpp new file mode 100644 index 00000000..9011d598 --- /dev/null +++ b/folly/wangle/acceptor/ManagedConnection.cpp @@ -0,0 +1,57 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace folly { namespace wangle { + +ManagedConnection::ManagedConnection() + : connectionManager_(nullptr) { +} + +ManagedConnection::~ManagedConnection() { + if (connectionManager_) { + connectionManager_->removeConnection(this); + } +} + +void +ManagedConnection::resetTimeout() { + if (connectionManager_) { + connectionManager_->scheduleTimeout(this); + } +} + +void +ManagedConnection::scheduleTimeout( + folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout) { + if (connectionManager_) { + connectionManager_->scheduleTimeout(callback, timeout); + } +} + +////////////////////// Globals ///////////////////// + +std::ostream& +operator<<(std::ostream& os, const ManagedConnection& conn) { + conn.describe(os); + return os; +} + +}} // folly::wangle diff --git a/folly/wangle/acceptor/ManagedConnection.h b/folly/wangle/acceptor/ManagedConnection.h new file mode 100644 index 00000000..50e7c057 --- /dev/null +++ b/folly/wangle/acceptor/ManagedConnection.h @@ -0,0 +1,115 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class ConnectionManager; + +/** + * Interface describing a connection that can be managed by a + * container such as an Acceptor. + */ +class ManagedConnection: + public folly::HHWheelTimer::Callback, + public folly::DelayedDestruction { + public: + + ManagedConnection(); + + // HHWheelTimer::Callback API (left for subclasses to implement). + virtual void timeoutExpired() noexcept = 0; + + /** + * Print a human-readable description of the connection. + * @param os Destination stream. + */ + virtual void describe(std::ostream& os) const = 0; + + /** + * Check whether the connection has any requests outstanding. + */ + virtual bool isBusy() const = 0; + + /** + * Notify the connection that a shutdown is pending. This method will be + * called at the beginning of graceful shutdown. + */ + virtual void notifyPendingShutdown() = 0; + + /** + * Instruct the connection that it should shutdown as soon as it is + * safe. This is called after notifyPendingShutdown(). + */ + virtual void closeWhenIdle() = 0; + + /** + * Forcibly drop a connection. + * + * If a request is in progress, this should cause the connection to be + * closed with a reset. + */ + virtual void dropConnection() = 0; + + /** + * Dump the state of the connection to the log + */ + virtual void dumpConnectionState(uint8_t loglevel) = 0; + + /** + * If the connection has a connection manager, reset the timeout + * countdown. + * @note If the connection manager doesn't have the connection scheduled + * for a timeout already, this method will schedule one. If the + * connection manager does have the connection connection scheduled + * for a timeout, this method will push back the timeout to N msec + * from now, where N is the connection manager's timer interval. + */ + virtual void resetTimeout(); + + // Schedule an arbitrary timeout on the HHWheelTimer + virtual void scheduleTimeout( + folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout); + + ConnectionManager* getConnectionManager() { + return connectionManager_; + } + + protected: + virtual ~ManagedConnection(); + + private: + friend class ConnectionManager; + + void setConnectionManager(ConnectionManager* mgr) { + connectionManager_ = mgr; + } + + ConnectionManager* connectionManager_; + + folly::SafeIntrusiveListHook listHook_; +}; + +std::ostream& operator<<(std::ostream& os, const ManagedConnection& conn); + +}} // folly::wangle diff --git a/folly/wangle/acceptor/NetworkAddress.h b/folly/wangle/acceptor/NetworkAddress.h new file mode 100644 index 00000000..36980371 --- /dev/null +++ b/folly/wangle/acceptor/NetworkAddress.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +/** + * A simple wrapper around SocketAddress that represents + * a network in CIDR notation + */ +class NetworkAddress { +public: + /** + * Create a NetworkAddress for an addr/prefixLen + * @param addr IPv4 or IPv6 address of the network + * @param prefixLen Prefix length, in bits + */ + NetworkAddress(const folly::SocketAddress& addr, + unsigned prefixLen): + addr_(addr), prefixLen_(prefixLen) {} + + /** Get the network address */ + const folly::SocketAddress& getAddress() const { + return addr_; + } + + /** Get the prefix length in bits */ + unsigned getPrefixLength() const { return prefixLen_; } + + /** Check whether a given address lies within the network */ + bool contains(const folly::SocketAddress& addr) const { + return addr_.prefixMatch(addr, prefixLen_); + } + + /** Comparison operator to enable use in ordered collections */ + bool operator<(const NetworkAddress& other) const { + if (addr_ < other.addr_) { + return true; + } else if (other.addr_ < addr_) { + return false; + } else { + return (prefixLen_ < other.prefixLen_); + } + } + +private: + folly::SocketAddress addr_; + unsigned prefixLen_; +}; + +} // namespace diff --git a/folly/wangle/acceptor/ServerSocketConfig.h b/folly/wangle/acceptor/ServerSocketConfig.h new file mode 100644 index 00000000..3a722653 --- /dev/null +++ b/folly/wangle/acceptor/ServerSocketConfig.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace folly { + +/** + * Configuration for a single Acceptor. + * + * This configures not only accept behavior, but also some types of SSL + * behavior that may make sense to configure on a per-VIP basis (e.g. which + * cert(s) we use, etc). + */ +struct ServerSocketConfig { + ServerSocketConfig() { + // generate a single random current seed + uint8_t seed[32]; + folly::Random::secureRandom(seed, sizeof(seed)); + initialTicketSeeds.currentSeeds.push_back( + SSLUtil::hexlify(std::string((char *)seed, sizeof(seed)))); + } + + bool isSSL() const { return !(sslContextConfigs.empty()); } + + /** + * Set/get the socket options to apply on all downstream connections. + */ + void setSocketOptions( + const AsyncSocket::OptionMap& opts) { + socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily()); + } + AsyncSocket::OptionMap& + getSocketOptions() { + return socketOptions_; + } + const AsyncSocket::OptionMap& + getSocketOptions() const { + return socketOptions_; + } + + bool hasExternalPrivateKey() const { + for (const auto& cfg : sslContextConfigs) { + if (!cfg.isLocalPrivateKey) { + return true; + } + } + return false; + } + + /** + * The name of this acceptor; used for stats/reporting purposes. + */ + std::string name; + + /** + * The depth of the accept queue backlog. + */ + uint32_t acceptBacklog{1024}; + + /** + * The number of milliseconds a connection can be idle before we close it. + */ + std::chrono::milliseconds connectionIdleTimeout{600000}; + + /** + * The address to bind to. + */ + SocketAddress bindAddress; + + /** + * Options for controlling the SSL cache. + */ + SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200}; + + /** + * The initial TLS ticket seeds. + */ + TLSTicketKeySeeds initialTicketSeeds; + + /** + * The configs for all the SSL_CTX for use by this Acceptor. + */ + std::vector sslContextConfigs; + + /** + * Determines if the Acceptor does strict checking when loading the SSL + * contexts. + */ + bool strictSSL{true}; + + /** + * Maximum number of concurrent pending SSL handshakes + */ + uint32_t maxConcurrentSSLHandshakes{30720}; + + private: + AsyncSocket::OptionMap socketOptions_; +}; + +} // folly diff --git a/folly/wangle/acceptor/SocketOptions.cpp b/folly/wangle/acceptor/SocketOptions.cpp new file mode 100644 index 00000000..3a159b37 --- /dev/null +++ b/folly/wangle/acceptor/SocketOptions.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +namespace folly { + +AsyncSocket::OptionMap filterIPSocketOptions( + const AsyncSocket::OptionMap& allOptions, + const int addrFamily) { + AsyncSocket::OptionMap opts; + int exclude; + if (addrFamily == AF_INET) { + exclude = IPPROTO_IPV6; + } else if (addrFamily == AF_INET6) { + exclude = IPPROTO_IP; + } else { + LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6"; + return opts; + } + for (const auto& opt: allOptions) { + if (opt.first.level != exclude) { + opts[opt.first] = opt.second; + } + } + return opts; +} + +} diff --git a/folly/wangle/acceptor/SocketOptions.h b/folly/wangle/acceptor/SocketOptions.h new file mode 100644 index 00000000..37ba3711 --- /dev/null +++ b/folly/wangle/acceptor/SocketOptions.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +/** + * Returns a copy of the socket options excluding options with the given + * level. + */ +AsyncSocket::OptionMap filterIPSocketOptions( + const AsyncSocket::OptionMap& allOptions, + const int addrFamily); + +} diff --git a/folly/wangle/acceptor/TransportInfo.cpp b/folly/wangle/acceptor/TransportInfo.cpp new file mode 100644 index 00000000..0f063b7c --- /dev/null +++ b/folly/wangle/acceptor/TransportInfo.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include +#include + +using std::chrono::microseconds; +using std::map; +using std::string; + +namespace folly { + +bool TransportInfo::initWithSocket(const AsyncSocket* sock) { +#if defined(__linux__) || defined(__FreeBSD__) + if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) { + tcpinfoErrno = errno; + return false; + } + rtt = microseconds(tcpinfo.tcpi_rtt); + validTcpinfo = true; +#else + tcpinfoErrno = EINVAL; + rtt = microseconds(-1); +#endif + return true; +} + +int64_t TransportInfo::readRTT(const AsyncSocket* sock) { +#if defined(__linux__) || defined(__FreeBSD__) + struct tcp_info tcpinfo; + if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) { + return -1; + } + return tcpinfo.tcpi_rtt; +#else + return -1; +#endif +} + +#if defined(__linux__) || defined(__FreeBSD__) +bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo, + const AsyncSocket* sock) { + socklen_t len = sizeof(struct tcp_info); + if (!sock) { + return false; + } + if (getsockopt(sock->getFd(), IPPROTO_TCP, + TCP_INFO, (void*) tcpinfo, &len) < 0) { + VLOG(4) << "Error calling getsockopt(): " << strerror(errno); + return false; + } + return true; +} +#endif + +} // folly diff --git a/folly/wangle/acceptor/TransportInfo.h b/folly/wangle/acceptor/TransportInfo.h new file mode 100644 index 00000000..e11021b1 --- /dev/null +++ b/folly/wangle/acceptor/TransportInfo.h @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +#include +#include +#include + +namespace folly { +class AsyncSocket; + +/** + * A structure that encapsulates byte counters related to the HTTP headers. + */ +struct HTTPHeaderSize { + /** + * The number of bytes used to represent the header after compression or + * before decompression. If header compression is not supported, the value + * is set to 0. + */ + size_t compressed{0}; + + /** + * The number of bytes used to represent the serialized header before + * compression or after decompression, in plain-text format. + */ + size_t uncompressed{0}; +}; + +struct TransportInfo { + /* + * timestamp of when the connection handshake was completed + */ + std::chrono::steady_clock::time_point acceptTime{}; + + /* + * connection RTT (Round-Trip Time) + */ + std::chrono::microseconds rtt{0}; + +#if defined(__linux__) || defined(__FreeBSD__) + /* + * TCP information as fetched from getsockopt(2) + */ + tcp_info tcpinfo { +#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17 + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32 +#else + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29 +#endif // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17 + }; +#endif // defined(__linux__) || defined(__FreeBSD__) + + /* + * time for setting the connection, from the moment in was accepted until it + * is established. + */ + std::chrono::milliseconds setupTime{0}; + + /* + * time for setting up the SSL connection or SSL handshake + */ + std::chrono::milliseconds sslSetupTime{0}; + + /* + * The name of the SSL ciphersuite used by the transaction's + * transport. Returns null if the transport is not SSL. + */ + const char* sslCipher{nullptr}; + + /* + * The SSL server name used by the transaction's + * transport. Returns null if the transport is not SSL. + */ + const char* sslServerName{nullptr}; + + /* + * list of ciphers sent by the client + */ + std::string sslClientCiphers{}; + + /* + * list of compression methods sent by the client + */ + std::string sslClientComprMethods{}; + + /* + * list of TLS extensions sent by the client + */ + std::string sslClientExts{}; + + /* + * hash of all the SSL parameters sent by the client + */ + std::string sslSignature{}; + + /* + * list of ciphers supported by the server + */ + std::string sslServerCiphers{}; + + /* + * guessed "(os) (browser)" based on SSL Signature + */ + std::string guessedUserAgent{}; + + /** + * The result of SSL NPN negotiation. + */ + std::string sslNextProtocol{}; + + /* + * total number of bytes sent over the connection + */ + int64_t totalBytes{0}; + + /** + * header bytes read + */ + HTTPHeaderSize ingressHeader; + + /* + * header bytes written + */ + HTTPHeaderSize egressHeader; + + /* + * Here is how the timeToXXXByte variables are planned out: + * 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_ + * 2. You can get the timing between two ByteEvents by calculating their + * differences. For example: + * timeToLastBodyByteAck - timeToFirstByte + * => Total time to deliver the body + * 3. The calculation in point (2) is typically done outside acceptor + * + * Future plan: + * We should log the timestamps (TimePoints) and allow + * the consumer to calculate the latency whatever it + * wants instead of calculating them in wangle, for the sake of flexibility. + * For example: + * 1. TimePoint reqStartTimestamp; + * 2. TimePoint firstHeaderByteSentTimestamp; + * 3. TimePoint firstBodyByteTimestamp; + * 3. TimePoint lastBodyByteTimestamp; + * 4. TimePoint lastBodyByteAckTimestamp; + */ + + /* + * time to first header byte written to the kernel send buffer + * NOTE: It is not 100% accurate since TAsyncSocket does not do + * do callback on partial write. + */ + int32_t timeToFirstHeaderByte{-1}; + + /* + * time to first body byte written to the kernel send buffer + */ + int32_t timeToFirstByte{-1}; + + /* + * time to last body byte written to the kernel send buffer + */ + int32_t timeToLastByte{-1}; + + /* + * time to TCP Ack received for the last written body byte + */ + int32_t timeToLastBodyByteAck{-1}; + + /* + * time it took the client to ACK the last byte, from the moment when the + * kernel sent the last byte to the client and until it received the ACK + * for that byte + */ + int32_t lastByteAckLatency{-1}; + + /* + * time spent inside wangle + */ + int32_t proxyLatency{-1}; + + /* + * time between connection accepted and client message headers completed + */ + int32_t clientLatency{-1}; + + /* + * latency for communication with the server + */ + int32_t serverLatency{-1}; + + /* + * time used to get a usable connection. + */ + int32_t connectLatency{-1}; + + /* + * body bytes written + */ + uint32_t egressBodySize{0}; + + /* + * value of errno in case of getsockopt() error + */ + int tcpinfoErrno{0}; + + /* + * bytes read & written during SSL Setup + */ + uint32_t sslSetupBytesWritten{0}; + uint32_t sslSetupBytesRead{0}; + + /** + * SSL error detail + */ + uint32_t sslError{0}; + + /** + * body bytes read + */ + uint32_t ingressBodySize{0}; + + /* + * The SSL version used by the transaction's transport, in + * OpenSSL's format: 4 bits for the major version, followed by 4 bits + * for the minor version. Returns zero for non-SSL. + */ + uint16_t sslVersion{0}; + + /* + * The SSL certificate size. + */ + uint16_t sslCertSize{0}; + + /** + * response status code + */ + uint16_t statusCode{0}; + + /* + * The SSL mode for the transaction's transport: new session, + * resumed session, or neither (non-SSL). + */ + SSLResumeEnum sslResume{SSLResumeEnum::NA}; + + /* + * true if the tcpinfo was successfully read from the kernel + */ + bool validTcpinfo{false}; + + /* + * true if the connection is SSL, false otherwise + */ + bool ssl{false}; + + /* + * get the RTT value in milliseconds + */ + std::chrono::milliseconds getRttMs() const { + return std::chrono::duration_cast(rtt); + } + + /* + * initialize the fields related with tcp_info + */ + bool initWithSocket(const AsyncSocket* sock); + + /* + * Get the kernel's estimate of round-trip time (RTT) to the transport's peer + * in microseconds. Returns -1 on error. + */ + static int64_t readRTT(const AsyncSocket* sock); + +#if defined(__linux__) || defined(__FreeBSD__) + /* + * perform the getsockopt(2) syscall to fetch TCP info for a given socket + */ + static bool readTcpInfo(struct tcp_info* tcpinfo, + const AsyncSocket* sock); +#endif +}; + +} // folly diff --git a/folly/wangle/bootstrap/BootstrapTest.cpp b/folly/wangle/bootstrap/BootstrapTest.cpp new file mode 100644 index 00000000..6b902d20 --- /dev/null +++ b/folly/wangle/bootstrap/BootstrapTest.cpp @@ -0,0 +1,171 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "folly/wangle/bootstrap/ServerBootstrap.h" +#include "folly/wangle/bootstrap/ClientBootstrap.h" +#include "folly/wangle/channel/ChannelHandler.h" + +#include +#include + +using namespace folly::wangle; +using namespace folly; + +typedef ChannelPipeline> Pipeline; + +class TestServer : public ServerBootstrap { + Pipeline* newPipeline(std::shared_ptr) { + return nullptr; + } +}; + +class TestClient : public ClientBootstrap { + Pipeline* newPipeline(std::shared_ptr sock) { + CHECK(sock->good()); + + // We probably aren't connected immedately, check after a small delay + EventBaseManager::get()->getEventBase()->runAfterDelay([sock](){ + CHECK(sock->readable()); + }, 100); + return nullptr; + } +}; + +class TestPipelineFactory : public PipelineFactory { + public: + Pipeline* newPipeline(std::shared_ptr sock) { + pipelines++; + return new Pipeline(); + } + std::atomic pipelines{0}; +}; + +TEST(Bootstrap, Basic) { + TestServer server; + TestClient client; +} + +TEST(Bootstrap, ServerWithPipeline) { + TestServer server; + server.childPipeline(std::make_shared()); + server.bind(0); + server.stop(); +} + +TEST(Bootstrap, ClientServerTest) { + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.bind(0); + auto base = EventBaseManager::get()->getEventBase(); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.connect(address); + base->loop(); + server.stop(); + + CHECK(factory->pipelines == 1); +} + +TEST(Bootstrap, ClientConnectionManagerTest) { + // Create a single IO thread, and verify that + // client connections are pooled properly + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(std::make_shared(1)); + server.bind(0); + auto base = EventBaseManager::get()->getEventBase(); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.connect(address); + + TestClient client2; + client2.connect(address); + + base->loop(); + server.stop(); + + CHECK(factory->pipelines == 2); +} + +TEST(Bootstrap, ServerAcceptGroupTest) { + // Verify that server is using the accept IO group + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(std::make_shared(1), nullptr); + server.bind(0); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + boost::barrier barrier(2); + auto thread = std::thread([&](){ + TestClient client; + client.connect(address); + EventBaseManager::get()->getEventBase()->loop(); + barrier.wait(); + }); + barrier.wait(); + server.stop(); + thread.join(); + + CHECK(factory->pipelines == 1); +} + +TEST(Bootstrap, ServerAcceptGroup2Test) { + // Verify that server is using the accept IO group + + // Check if reuse port is supported, if not, don't run this test + try { + EventBase base; + auto serverSocket = AsyncServerSocket::newSocket(&base); + serverSocket->bind(0); + serverSocket->listen(0); + serverSocket->startAccepting(); + serverSocket->setReusePortEnabled(true); + serverSocket->stopAccepting(); + } catch(...) { + LOG(INFO) << "Reuse port probably not supported"; + return; + } + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(std::make_shared(4), nullptr); + server.bind(0); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.connect(address); + EventBaseManager::get()->getEventBase()->loop(); + + server.stop(); + + CHECK(factory->pipelines == 1); +} diff --git a/folly/wangle/bootstrap/ClientBootstrap.h b/folly/wangle/bootstrap/ClientBootstrap.h new file mode 100644 index 00000000..8ee8fad9 --- /dev/null +++ b/folly/wangle/bootstrap/ClientBootstrap.h @@ -0,0 +1,54 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace folly { + +/* + * A thin wrapper around ChannelPipeline and AsyncSocket to match + * ServerBootstrap. On connect() a new pipeline is created. + */ +template +class ClientBootstrap { + public: + ClientBootstrap() { + } + ClientBootstrap* bind(int port) { + port_ = port; + return this; + } + ClientBootstrap* connect(SocketAddress address) { + pipeline_.reset( + newPipeline( + AsyncSocket::newSocket(EventBaseManager::get()->getEventBase(), address) + )); + return this; + } + + virtual ~ClientBootstrap() {} + + protected: + std::unique_ptr pipeline_; + + int port_; + + virtual Pipeline* newPipeline(std::shared_ptr socket) = 0; +}; + +} // namespace diff --git a/folly/wangle/bootstrap/ServerBootstrap-inl.h b/folly/wangle/bootstrap/ServerBootstrap-inl.h new file mode 100644 index 00000000..a224a9f9 --- /dev/null +++ b/folly/wangle/bootstrap/ServerBootstrap-inl.h @@ -0,0 +1,134 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace folly { + +template +class ServerAcceptor : public Acceptor { + typedef std::unique_ptr PipelinePtr; + + class ServerConnection : public wangle::ManagedConnection { + public: + explicit ServerConnection(PipelinePtr pipeline) + : pipeline_(std::move(pipeline)) {} + + ~ServerConnection() { + } + + void timeoutExpired() noexcept { + } + + void describe(std::ostream& os) const {} + bool isBusy() const { + return false; + } + void notifyPendingShutdown() {} + void closeWhenIdle() {} + void dropConnection() {} + void dumpConnectionState(uint8_t loglevel) {} + private: + PipelinePtr pipeline_; + }; + + public: + explicit ServerAcceptor( + std::shared_ptr> pipelineFactory) + : Acceptor(ServerSocketConfig()) + , pipelineFactory_(pipelineFactory) { + Acceptor::init(nullptr, &base_); + } + + /* See Acceptor::onNewConnection for details */ + void onNewConnection( + AsyncSocket::UniquePtr transport, const SocketAddress* address, + const std::string& nextProtocolName, const TransportInfo& tinfo) { + + std::unique_ptr + pipeline(pipelineFactory_->newPipeline( + std::shared_ptr( + transport.release(), + folly::DelayedDestruction::Destructor()))); + auto connection = new ServerConnection(std::move(pipeline)); + Acceptor::addConnection(connection); + } + + ~ServerAcceptor() { + Acceptor::dropAllConnections(); + } + + private: + EventBase base_; + + std::shared_ptr> pipelineFactory_; +}; + +template +class ServerAcceptorFactory : public AcceptorFactory { + public: + explicit ServerAcceptorFactory( + std::shared_ptr> factory) + : factory_(factory) {} + + std::shared_ptr newAcceptor() { + return std::make_shared>(factory_); + } + private: + std::shared_ptr> factory_; +}; + +class ServerWorkerFactory : public folly::wangle::ThreadFactory { + public: + explicit ServerWorkerFactory(std::shared_ptr acceptorFactory) + : internalFactory_( + std::make_shared("BootstrapWorker")) + , acceptorFactory_(acceptorFactory) + {} + virtual std::thread newThread(folly::Func&& func) override; + + void setInternalFactory( + std::shared_ptr internalFactory); + void setNamePrefix(folly::StringPiece prefix); + + template + void forEachWorker(F&& f); + + private: + std::shared_ptr internalFactory_; + folly::RWSpinLock workersLock_; + std::map> workers_; + int32_t nextWorkerId_{0}; + + std::shared_ptr acceptorFactory_; +}; + +template +void ServerWorkerFactory::forEachWorker(F&& f) { + folly::RWSpinLock::ReadHolder guard(workersLock_); + for (const auto& kv : workers_) { + f(kv.second.get()); + } +} + +} // namespace diff --git a/folly/wangle/bootstrap/ServerBootstrap.cpp b/folly/wangle/bootstrap/ServerBootstrap.cpp new file mode 100644 index 00000000..1f07fadf --- /dev/null +++ b/folly/wangle/bootstrap/ServerBootstrap.cpp @@ -0,0 +1,54 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +namespace folly { + +std::thread ServerWorkerFactory::newThread( + folly::Func&& func) { + auto id = nextWorkerId_++; + auto worker = acceptorFactory_->newAcceptor(); + { + folly::RWSpinLock::WriteHolder guard(workersLock_); + workers_.insert({id, worker}); + } + return internalFactory_->newThread([=](){ + EventBaseManager::get()->setEventBase(worker->getEventBase(), false); + func(); + EventBaseManager::get()->clearEventBase(); + + worker->drainAllConnections(); + { + folly::RWSpinLock::WriteHolder guard(workersLock_); + workers_.erase(id); + } + }); +} + +void ServerWorkerFactory::setInternalFactory( + std::shared_ptr internalFactory) { + CHECK(workers_.empty()); + internalFactory_ = internalFactory; +} + +void ServerWorkerFactory::setNamePrefix(folly::StringPiece prefix) { + CHECK(workers_.empty()); + internalFactory_->setNamePrefix(prefix); +} + +} // namespace diff --git a/folly/wangle/bootstrap/ServerBootstrap.h b/folly/wangle/bootstrap/ServerBootstrap.h new file mode 100644 index 00000000..f77be3fb --- /dev/null +++ b/folly/wangle/bootstrap/ServerBootstrap.h @@ -0,0 +1,238 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { + +/* + * ServerBootstrap is a parent class intended to set up a + * high-performance TCP accepting server. It will manage a pool of + * accepting threads, any number of accepting sockets, a pool of + * IO-worker threads, and connection pool for each IO thread for you. + * + * The output is given as a ChannelPipeline template: given a + * PipelineFactory, it will create a new pipeline for each connection, + * and your server can handle the incoming bytes. + * + * BACKWARDS COMPATIBLITY: for servers already taking a pool of + * Acceptor objects, an AcceptorFactory can be given directly instead + * of a pipeline factory. + */ +template +class ServerBootstrap { + public: + /* TODO(davejwatson) + * + * If there is any work to be done BEFORE handing the work to IO + * threads, this handler is where the pipeline to do it would be + * set. + * + * This could be used for things like logging, load balancing, or + * advanced load balancing on IO threads. Netty also provides this. + */ + ServerBootstrap* handler() { + return this; + } + + /* + * BACKWARDS COMPATIBILITY - an acceptor factory can be set. Your + * Acceptor is responsible for managing the connection pool. + * + * @param childHandler - acceptor factory to call for each IO thread + */ + ServerBootstrap* childHandler(std::shared_ptr childHandler) { + acceptorFactory_ = childHandler; + return this; + } + + /* + * Set a pipeline factory that will be called for each new connection + * + * @param factory pipeline factory to use for each new connection + */ + ServerBootstrap* childPipeline( + std::shared_ptr> factory) { + pipelineFactory_ = factory; + return this; + } + + /* + * Set the IO executor. If not set, a default one will be created + * with one thread per core. + * + * @param io_group - io executor to use for IO threads. + */ + ServerBootstrap* group( + std::shared_ptr io_group) { + return group(nullptr, io_group); + } + + /* + * Set the acceptor executor, and IO executor. + * + * If no acceptor executor is set, a single thread will be created for accepts + * If no IO executor is set, a default of one thread per core will be created + * + * @param group - acceptor executor to use for acceptor threads. + * @param io_group - io executor to use for IO threads. + */ + ServerBootstrap* group( + std::shared_ptr accept_group, + std::shared_ptr io_group) { + if (!accept_group) { + accept_group = std::make_shared( + 1, std::make_shared("Acceptor Thread")); + } + if (!io_group) { + io_group = std::make_shared( + 32, std::make_shared("IO Thread")); + } + auto factoryBase = io_group->getThreadFactory(); + CHECK(factoryBase); + auto factory = std::dynamic_pointer_cast( + factoryBase); + CHECK(factory); // Must be named thread factory + + CHECK(acceptorFactory_ || pipelineFactory_); + + if (acceptorFactory_) { + workerFactory_ = std::make_shared( + acceptorFactory_); + } else { + workerFactory_ = std::make_shared( + std::make_shared>(pipelineFactory_)); + } + workerFactory_->setInternalFactory(factory); + + acceptor_group_ = accept_group; + io_group_ = io_group; + + auto numThreads = io_group_->numThreads(); + io_group_->setNumThreads(0); + io_group_->setThreadFactory(workerFactory_); + io_group_->setNumThreads(numThreads); + + return this; + } + + /* + * Bind to a port and start listening. + * One of childPipeline or childHandler must be called before bind + * + * @param port Port to listen on + */ + void bind(int port) { + // TODO take existing socket + + if (!workerFactory_) { + group(nullptr); + } + + bool reusePort = false; + if (acceptor_group_->numThreads() > 1) { + reusePort = true; + } + + std::mutex sock_lock; + std::vector> new_sockets; + + auto startupFunc = [&](std::shared_ptr barrier){ + auto socket = folly::AsyncServerSocket::newSocket(); + sock_lock.lock(); + new_sockets.push_back(socket); + sock_lock.unlock(); + socket->setReusePortEnabled(reusePort); + socket->attachEventBase(EventBaseManager::get()->getEventBase()); + socket->bind(port); + // TODO Take ServerSocketConfig + socket->listen(1024); + socket->startAccepting(); + + if (port == 0) { + SocketAddress address; + socket->getAddress(&address); + port = address.getPort(); + } + + barrier->wait(); + }; + + auto bind0 = std::make_shared(2); + acceptor_group_->add(std::bind(startupFunc, bind0)); + bind0->wait(); + + auto barrier = std::make_shared(acceptor_group_->numThreads()); + for (int i = 1; i < acceptor_group_->numThreads(); i++) { + acceptor_group_->add(std::bind(startupFunc, barrier)); + } + barrier->wait(); + + // Startup all the threads + for(auto socket : new_sockets) { + workerFactory_->forEachWorker([this, socket](Acceptor* worker){ + socket->getEventBase()->runInEventBaseThread([this, worker, socket](){ + socket->addAcceptCallback(worker, worker->getEventBase()); + }); + }); + } + + for (auto& socket : new_sockets) { + sockets_.push_back(socket); + } + } + + /* + * Stop listening on all sockets. + */ + void stop() { + auto barrier = std::make_shared(sockets_.size() + 1); + for (auto socket : sockets_) { + socket->getEventBase()->runInEventBaseThread([barrier, socket]() { + socket->stopAccepting(); + socket->detachEventBase(); + barrier->wait(); + }); + } + barrier->wait(); + sockets_.clear(); + + acceptor_group_->join(); + io_group_->join(); + } + + /* + * Get the list of listening sockets + */ + std::vector>& + getSockets() { + return sockets_; + } + + private: + std::shared_ptr acceptor_group_; + std::shared_ptr io_group_; + + std::shared_ptr workerFactory_; + std::vector> sockets_; + + std::shared_ptr acceptorFactory_; + std::shared_ptr> pipelineFactory_; +}; + +} // namespace diff --git a/folly/wangle/channel/AsyncSocketHandler.h b/folly/wangle/channel/AsyncSocketHandler.h new file mode 100644 index 00000000..eb47cb05 --- /dev/null +++ b/folly/wangle/channel/AsyncSocketHandler.h @@ -0,0 +1,153 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class AsyncSocketHandler + : public folly::wangle::BytesToBytesHandler, + public AsyncSocket::ReadCallback { + public: + explicit AsyncSocketHandler( + std::shared_ptr socket) + : socket_(std::move(socket)) {} + + AsyncSocketHandler(AsyncSocketHandler&&) = default; + + ~AsyncSocketHandler() { + if (socket_) { + detachReadCallback(); + } + } + + void attachReadCallback() { + socket_->setReadCB(socket_->good() ? this : nullptr); + } + + void detachReadCallback() { + if (socket_->getReadCallback() == this) { + socket_->setReadCB(nullptr); + } + } + + void attachEventBase(folly::EventBase* eventBase) { + if (eventBase && !socket_->getEventBase()) { + socket_->attachEventBase(eventBase); + } + } + + void detachEventBase() { + detachReadCallback(); + if (socket_->getEventBase()) { + socket_->detachEventBase(); + } + } + + void attachPipeline(Context* ctx) override { + CHECK(!ctx_); + ctx_ = ctx; + } + + folly::wangle::Future write( + Context* ctx, + std::unique_ptr buf) override { + if (UNLIKELY(!buf)) { + return folly::wangle::makeFuture(); + } + + if (!socket_->good()) { + VLOG(5) << "socket is closed in write()"; + return folly::wangle::makeFuture(AsyncSocketException( + AsyncSocketException::AsyncSocketExceptionType::NOT_OPEN, + "socket is closed in write()")); + } + + auto cb = new WriteCallback(); + auto future = cb->promise_.getFuture(); + socket_->writeChain(cb, std::move(buf), ctx->getWriteFlags()); + return future; + }; + + folly::wangle::Future close(Context* ctx) { + if (socket_) { + detachReadCallback(); + socket_->closeNow(); + } + return folly::wangle::makeFuture(); + } + + // Must override to avoid warnings about hidden overloaded virtual due to + // AsyncSocket::ReadCallback::readEOF() + void readEOF(Context* ctx) override { + ctx->fireReadEOF(); + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + const auto readBufferSettings = ctx_->getReadBufferSettings(); + const auto ret = bufQueue_.preallocate( + readBufferSettings.first, + readBufferSettings.second); + *bufReturn = ret.first; + *lenReturn = ret.second; + } + + void readDataAvailable(size_t len) noexcept override { + bufQueue_.postallocate(len); + ctx_->fireRead(bufQueue_); + } + + void readEOF() noexcept override { + ctx_->fireReadEOF(); + } + + void readErr(const AsyncSocketException& ex) + noexcept override { + ctx_->fireReadException(make_exception_wrapper(ex)); + } + + private: + class WriteCallback : private AsyncSocket::WriteCallback { + void writeSuccess() noexcept override { + promise_.setValue(); + delete this; + } + + void writeErr(size_t bytesWritten, + const AsyncSocketException& ex) + noexcept override { + promise_.setException(ex); + delete this; + } + + private: + friend class AsyncSocketHandler; + folly::wangle::Promise promise_; + }; + + Context* ctx_{nullptr}; + folly::IOBufQueue bufQueue_; + std::shared_ptr socket_{nullptr}; +}; + +}} diff --git a/folly/wangle/channel/ChannelHandler.h b/folly/wangle/channel/ChannelHandler.h new file mode 100644 index 00000000..8e134a3f --- /dev/null +++ b/folly/wangle/channel/ChannelHandler.h @@ -0,0 +1,192 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace folly { namespace wangle { + +template +class ChannelHandler { + public: + typedef Rin rin; + typedef Rout rout; + typedef Win win; + typedef Wout wout; + typedef ChannelHandlerContext Context; + virtual ~ChannelHandler() {} + + virtual void read(Context* ctx, Rin msg) = 0; + virtual void readEOF(Context* ctx) { + ctx->fireReadEOF(); + } + virtual void readException(Context* ctx, exception_wrapper e) { + ctx->fireReadException(std::move(e)); + } + + virtual Future write(Context* ctx, Win msg) = 0; + virtual Future close(Context* ctx) { + return ctx->fireClose(); + } + + virtual void attachPipeline(Context* ctx) {} + virtual void attachTransport(Context* ctx) {} + + virtual void detachPipeline(Context* ctx) {} + virtual void detachTransport(Context* ctx) {} + + /* + // Other sorts of things we might want, all shamelessly stolen from Netty + // inbound + virtual void exceptionCaught( + ChannelHandlerContext* ctx, + exception_wrapper e) {} + virtual void channelRegistered(ChannelHandlerContext* ctx) {} + virtual void channelUnregistered(ChannelHandlerContext* ctx) {} + virtual void channelActive(ChannelHandlerContext* ctx) {} + virtual void channelInactive(ChannelHandlerContext* ctx) {} + virtual void channelReadComplete(ChannelHandlerContext* ctx) {} + virtual void userEventTriggered(ChannelHandlerContext* ctx, void* evt) {} + virtual void channelWritabilityChanged(ChannelHandlerContext* ctx) {} + + // outbound + virtual Future bind( + ChannelHandlerContext* ctx, + SocketAddress localAddress) {} + virtual Future connect( + ChannelHandlerContext* ctx, + SocketAddress remoteAddress, SocketAddress localAddress) {} + virtual Future disconnect(ChannelHandlerContext* ctx) {} + virtual Future deregister(ChannelHandlerContext* ctx) {} + virtual Future read(ChannelHandlerContext* ctx) {} + virtual void flush(ChannelHandlerContext* ctx) {} + */ +}; + +template +class ChannelHandlerAdapter : public ChannelHandler { + public: + typedef typename ChannelHandler::Context Context; + + void read(Context* ctx, R msg) override { + ctx->fireRead(std::forward(msg)); + } + + Future write(Context* ctx, W msg) override { + return ctx->fireWrite(std::forward(msg)); + } +}; + +typedef ChannelHandlerAdapter> +BytesToBytesHandler; + +template +class ChannelHandlerPtr : public ChannelHandler< + typename Handler::rin, + typename Handler::rout, + typename Handler::win, + typename Handler::wout> { + public: + typedef typename std::conditional< + Shared, + std::shared_ptr, + Handler*>::type + HandlerPtr; + + typedef typename Handler::Context Context; + + explicit ChannelHandlerPtr(HandlerPtr handler) + : handler_(std::move(handler)) {} + + void setHandler(HandlerPtr handler) { + if (handler == handler_) { + return; + } + if (handler_ && ctx_) { + handler_->detachPipeline(ctx_); + } + handler_ = std::move(handler); + if (handler_ && ctx_) { + handler_->attachPipeline(ctx_); + if (ctx_->getTransport()) { + handler_->attachTransport(ctx_); + } + } + } + + void attachPipeline(Context* ctx) override { + ctx_ = ctx; + if (handler_) { + handler_->attachPipeline(ctx_); + } + } + + void attachTransport(Context* ctx) override { + ctx_ = ctx; + if (handler_) { + handler_->attachTransport(ctx_); + } + } + + void detachPipeline(Context* ctx) override { + ctx_ = ctx; + if (handler_) { + handler_->detachPipeline(ctx_); + } + } + + void detachTransport(Context* ctx) override { + ctx_ = ctx; + if (handler_) { + handler_->detachTransport(ctx_); + } + } + + void read(Context* ctx, typename Handler::rin msg) override { + DCHECK(handler_); + handler_->read(ctx, std::forward(msg)); + } + + void readEOF(Context* ctx) override { + DCHECK(handler_); + handler_->readEOF(ctx); + } + + void readException(Context* ctx, exception_wrapper e) override { + DCHECK(handler_); + handler_->readException(ctx, std::move(e)); + } + + Future write(Context* ctx, typename Handler::win msg) override { + DCHECK(handler_); + return handler_->write(ctx, std::forward(msg)); + } + + Future close(Context* ctx) override { + DCHECK(handler_); + return handler_->close(ctx); + } + + private: + Context* ctx_; + HandlerPtr handler_; +}; + +}} diff --git a/folly/wangle/channel/ChannelHandlerContext.h b/folly/wangle/channel/ChannelHandlerContext.h new file mode 100644 index 00000000..59ea3ae4 --- /dev/null +++ b/folly/wangle/channel/ChannelHandlerContext.h @@ -0,0 +1,252 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace folly { namespace wangle { + +template +class ChannelHandlerContext { + public: + virtual ~ChannelHandlerContext() {} + + virtual void fireRead(In msg) = 0; + virtual void fireReadEOF() = 0; + virtual void fireReadException(exception_wrapper e) = 0; + + virtual Future fireWrite(Out msg) = 0; + virtual Future fireClose() = 0; + + virtual std::shared_ptr getTransport() = 0; + + virtual void setWriteFlags(WriteFlags flags) = 0; + virtual WriteFlags getWriteFlags() = 0; + + virtual void setReadBufferSettings( + uint64_t minAvailable, + uint64_t allocationSize) = 0; + virtual std::pair getReadBufferSettings() = 0; + + /* TODO + template + virtual void addHandlerBefore(H&&) {} + template + virtual void addHandlerAfter(H&&) {} + template + virtual void replaceHandler(H&&) {} + virtual void removeHandler() {} + */ +}; + +class PipelineContext { + public: + virtual ~PipelineContext() {} + + virtual void attachTransport() = 0; + virtual void detachTransport() = 0; + + void link(PipelineContext* other) { + setNextIn(other); + other->setNextOut(this); + } + + protected: + virtual void setNextIn(PipelineContext* ctx) = 0; + virtual void setNextOut(PipelineContext* ctx) = 0; +}; + +template +class InboundChannelHandlerContext { + public: + virtual ~InboundChannelHandlerContext() {} + virtual void read(In msg) = 0; + virtual void readEOF() = 0; + virtual void readException(exception_wrapper e) = 0; +}; + +template +class OutboundChannelHandlerContext { + public: + virtual ~OutboundChannelHandlerContext() {} + virtual Future write(Out msg) = 0; + virtual Future close() = 0; +}; + +template +class ContextImpl : public ChannelHandlerContext, + public InboundChannelHandlerContext, + public OutboundChannelHandlerContext, + public PipelineContext { + public: + typedef typename H::rin Rin; + typedef typename H::rout Rout; + typedef typename H::win Win; + typedef typename H::wout Wout; + + template + explicit ContextImpl(P* pipeline, HandlerArg&& handlerArg) + : pipeline_(pipeline), + handler_(std::forward(handlerArg)) { + handler_.attachPipeline(this); + } + + ~ContextImpl() { + handler_.detachPipeline(this); + } + + H* getHandler() { + return &handler_; + } + + // PipelineContext overrides + void setNextIn(PipelineContext* ctx) override { + auto nextIn = dynamic_cast*>(ctx); + if (nextIn) { + nextIn_ = nextIn; + } else { + throw std::invalid_argument("wrong type in setNextIn"); + } + } + + void setNextOut(PipelineContext* ctx) override { + auto nextOut = dynamic_cast*>(ctx); + if (nextOut) { + nextOut_ = nextOut; + } else { + throw std::invalid_argument("wrong type in setNextOut"); + } + } + + void attachTransport() override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + handler_.attachTransport(this); + } + + void detachTransport() override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + handler_.detachTransport(this); + } + + // ChannelHandlerContext overrides + void fireRead(Rout msg) override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + if (nextIn_) { + nextIn_->read(std::forward(msg)); + } else { + LOG(WARNING) << "read reached end of pipeline"; + } + } + + void fireReadEOF() override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + if (nextIn_) { + nextIn_->readEOF(); + } else { + LOG(WARNING) << "readEOF reached end of pipeline"; + } + } + + void fireReadException(exception_wrapper e) override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + if (nextIn_) { + nextIn_->readException(std::move(e)); + } else { + LOG(WARNING) << "readException reached end of pipeline"; + } + } + + Future fireWrite(Wout msg) override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + if (nextOut_) { + return nextOut_->write(std::forward(msg)); + } else { + LOG(WARNING) << "write reached end of pipeline"; + return makeFuture(); + } + } + + Future fireClose() override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + if (nextOut_) { + return nextOut_->close(); + } else { + LOG(WARNING) << "close reached end of pipeline"; + return makeFuture(); + } + } + + std::shared_ptr getTransport() override { + return pipeline_->getTransport(); + } + + void setWriteFlags(WriteFlags flags) override { + pipeline_->setWriteFlags(flags); + } + + WriteFlags getWriteFlags() override { + return pipeline_->getWriteFlags(); + } + + void setReadBufferSettings( + uint64_t minAvailable, + uint64_t allocationSize) override { + pipeline_->setReadBufferSettings(minAvailable, allocationSize); + } + + std::pair getReadBufferSettings() override { + return pipeline_->getReadBufferSettings(); + } + + // InboundChannelHandlerContext overrides + void read(Rin msg) override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + handler_.read(this, std::forward(msg)); + } + + void readEOF() override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + handler_.readEOF(this); + } + + void readException(exception_wrapper e) override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + handler_.readException(this, std::move(e)); + } + + // OutboundChannelHandlerContext overrides + Future write(Win msg) override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + return handler_.write(this, std::forward(msg)); + } + + Future close() override { + typename P::DestructorGuard dg(static_cast(pipeline_)); + return handler_.close(this); + } + + private: + P* pipeline_; + H handler_; + InboundChannelHandlerContext* nextIn_{nullptr}; + OutboundChannelHandlerContext* nextOut_{nullptr}; +}; + +}} diff --git a/folly/wangle/channel/ChannelPipeline.h b/folly/wangle/channel/ChannelPipeline.h new file mode 100644 index 00000000..07d10ba8 --- /dev/null +++ b/folly/wangle/channel/ChannelPipeline.h @@ -0,0 +1,342 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +/* + * R is the inbound type, i.e. inbound calls start with pipeline.read(R) + * W is the outbound type, i.e. outbound calls start with pipeline.write(W) + */ +template +class ChannelPipeline; + +template +class ChannelPipeline : public DelayedDestruction { + public: + ChannelPipeline() {} + ~ChannelPipeline() {} + + std::shared_ptr getTransport() { + return transport_; + } + + void setWriteFlags(WriteFlags flags) { + writeFlags_ = flags; + } + + WriteFlags getWriteFlags() { + return writeFlags_; + } + + void setReadBufferSettings(uint64_t minAvailable, uint64_t allocationSize) { + readBufferSettings_ = std::make_pair(minAvailable, allocationSize); + } + + std::pair getReadBufferSettings() { + return readBufferSettings_; + } + + void read(R msg) { + front_->read(std::forward(msg)); + } + + void readEOF() { + front_->readEOF(); + } + + void readException(exception_wrapper e) { + front_->readException(std::move(e)); + } + + Future write(W msg) { + return back_->write(std::forward(msg)); + } + + Future close() { + return back_->close(); + } + + template + ChannelPipeline& addBack(H&& handler) { + ctxs_.push_back(folly::make_unique>( + this, std::forward(handler))); + return *this; + } + + template + ChannelPipeline& addFront(H&& handler) { + ctxs_.insert( + ctxs_.begin(), + folly::make_unique>( + this, + std::forward(handler))); + return *this; + } + + template + H* getHandler(int i) { + auto ctx = dynamic_cast*>(ctxs_[i].get()); + CHECK(ctx); + return ctx->getHandler(); + } + + void finalize() { + finalizeHelper(); + InboundChannelHandlerContext* front; + front_ = dynamic_cast*>( + ctxs_.front().get()); + if (!front_) { + throw std::invalid_argument("wrong type for first handler"); + } + } + + protected: + explicit ChannelPipeline(bool shouldFinalize) { + CHECK(!shouldFinalize); + } + + void finalizeHelper() { + if (ctxs_.empty()) { + return; + } + + for (int i = 0; i < ctxs_.size() - 1; i++) { + ctxs_[i]->link(ctxs_[i+1].get()); + } + + back_ = dynamic_cast*>(ctxs_.back().get()); + if (!back_) { + throw std::invalid_argument("wrong type for last handler"); + } + } + + PipelineContext* getLocalFront() { + return ctxs_.empty() ? nullptr : ctxs_.front().get(); + } + + static const bool is_end{true}; + + std::shared_ptr transport_; + WriteFlags writeFlags_{WriteFlags::NONE}; + std::pair readBufferSettings_{2048, 2048}; + + void attachPipeline() {} + + void attachTransport( + std::shared_ptr transport) { + transport_ = std::move(transport); + } + + void detachTransport() { + transport_ = nullptr; + } + + OutboundChannelHandlerContext* back_{nullptr}; + + private: + InboundChannelHandlerContext* front_{nullptr}; + std::vector> ctxs_; +}; + +template +class ChannelPipeline + : public ChannelPipeline { + protected: + template + ChannelPipeline( + bool shouldFinalize, + HandlerArg&& handlerArg, + HandlersArgs&&... handlersArgs) + : ChannelPipeline( + false, + std::forward(handlersArgs)...), + ctx_(this, std::forward(handlerArg)) { + if (shouldFinalize) { + finalize(); + } + } + + public: + template + explicit ChannelPipeline(HandlersArgs&&... handlersArgs) + : ChannelPipeline(true, std::forward(handlersArgs)...) {} + + ~ChannelPipeline() {} + + void destroy() override { } + + void read(R msg) { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + front_->read(std::forward(msg)); + } + + void readEOF() { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + front_->readEOF(); + } + + void readException(exception_wrapper e) { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + front_->readException(std::move(e)); + } + + Future write(W msg) { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + return back_->write(std::forward(msg)); + } + + Future close() { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + return back_->close(); + } + + void attachTransport( + std::shared_ptr transport) { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + CHECK((!ChannelPipeline::transport_)); + ChannelPipeline::attachTransport(std::move(transport)); + forEachCtx([&](PipelineContext* ctx){ + ctx->attachTransport(); + }); + } + + void detachTransport() { + typename ChannelPipeline::DestructorGuard dg( + static_cast(this)); + ChannelPipeline::detachTransport(); + forEachCtx([&](PipelineContext* ctx){ + ctx->detachTransport(); + }); + } + + std::shared_ptr getTransport() { + return ChannelPipeline::transport_; + } + + template + ChannelPipeline& addBack(H&& handler) { + ChannelPipeline::addBack(std::move(handler)); + return *this; + } + + template + ChannelPipeline& addFront(H&& handler) { + ctxs_.insert( + ctxs_.begin(), + folly::make_unique>( + this, + std::move(handler))); + return *this; + } + + template + H* getHandler(size_t i) { + if (i > ctxs_.size()) { + return ChannelPipeline::template getHandler( + i - (ctxs_.size() + 1)); + } else { + auto pctx = (i == ctxs_.size()) ? &ctx_ : ctxs_[i].get(); + auto ctx = dynamic_cast*>(pctx); + return ctx->getHandler(); + } + } + + void finalize() { + finalizeHelper(); + auto ctx = ctxs_.empty() ? &ctx_ : ctxs_.front().get(); + front_ = dynamic_cast*>(ctx); + if (!front_) { + throw std::invalid_argument("wrong type for first handler"); + } + } + + protected: + void finalizeHelper() { + ChannelPipeline::finalizeHelper(); + back_ = ChannelPipeline::back_; + if (!back_) { + auto is_end = ChannelPipeline::is_end; + CHECK(is_end); + back_ = dynamic_cast*>(&ctx_); + if (!back_) { + throw std::invalid_argument("wrong type for last handler"); + } + } + + if (!ctxs_.empty()) { + for (int i = 0; i < ctxs_.size() - 1; i++) { + ctxs_[i]->link(ctxs_[i+1].get()); + } + ctxs_.back()->link(&ctx_); + } + + auto nextFront = ChannelPipeline::getLocalFront(); + if (nextFront) { + ctx_.link(nextFront); + } + } + + PipelineContext* getLocalFront() { + return ctxs_.empty() ? &ctx_ : ctxs_.front().get(); + } + + static const bool is_end{false}; + InboundChannelHandlerContext* front_{nullptr}; + OutboundChannelHandlerContext* back_{nullptr}; + + private: + template + void forEachCtx(const F& func) { + for (auto& ctx : ctxs_) { + func(ctx.get()); + } + func(&ctx_); + } + + ContextImpl ctx_; + std::vector> ctxs_; +}; + +}} + +namespace folly { + +class AsyncSocket; + +template +class PipelineFactory { + public: + virtual Pipeline* newPipeline(std::shared_ptr) = 0; + virtual ~PipelineFactory() {} +}; + +} diff --git a/folly/wangle/channel/OutputBufferingHandler.h b/folly/wangle/channel/OutputBufferingHandler.h new file mode 100644 index 00000000..06a053d7 --- /dev/null +++ b/folly/wangle/channel/OutputBufferingHandler.h @@ -0,0 +1,79 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +/* + * OutputBufferingHandler buffers writes in order to minimize syscalls. The + * transport will be written to once per event loop instead of on every write. + */ +class OutputBufferingHandler : public BytesToBytesHandler, + protected EventBase::LoopCallback { + public: + Future write(Context* ctx, std::unique_ptr buf) override { + CHECK(buf); + if (!queueSends_) { + return ctx->fireWrite(std::move(buf)); + } else { + ctx_ = ctx; + // Delay sends to optimize for fewer syscalls + if (!sends_) { + DCHECK(!isLoopCallbackScheduled()); + // Buffer all the sends, and call writev once per event loop. + sends_ = std::move(buf); + ctx->getTransport()->getEventBase()->runInLoop(this); + } else { + DCHECK(isLoopCallbackScheduled()); + sends_->prependChain(std::move(buf)); + } + Promise p; + auto f = p.getFuture(); + promises_.push_back(std::move(p)); + return f; + } + } + + void runLoopCallback() noexcept override { + MoveWrapper>> promises(std::move(promises_)); + ctx_->fireWrite(std::move(sends_)).then([promises](Try&& t) mutable { + try { + t.throwIfFailed(); + for (auto& p : *promises) { + p.setValue(); + } + } catch (...) { + for (auto& p : *promises) { + p.setException(std::current_exception()); + } + } + }); + } + + std::vector> promises_; + std::unique_ptr sends_{nullptr}; + bool queueSends_{true}; + Context* ctx_; +}; + +}} diff --git a/folly/wangle/channel/test/ChannelPipelineTest.cpp b/folly/wangle/channel/test/ChannelPipelineTest.cpp new file mode 100644 index 00000000..a058d5b1 --- /dev/null +++ b/folly/wangle/channel/test/ChannelPipelineTest.cpp @@ -0,0 +1,251 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace testing; + +typedef StrictMock> IntHandler; +typedef ChannelHandlerPtr IntHandlerPtr; + +ACTION(FireRead) { + arg0->fireRead(arg1); +} + +ACTION(FireReadEOF) { + arg0->fireReadEOF(); +} + +ACTION(FireReadException) { + arg0->fireReadException(arg1); +} + +ACTION(FireWrite) { + arg0->fireWrite(arg1); +} + +ACTION(FireClose) { + arg0->fireClose(); +} + +// Test move only types, among other things +TEST(ChannelTest, RealHandlersCompile) { + EventBase eb; + auto socket = AsyncSocket::newSocket(&eb); + // static + { + ChannelPipeline, + AsyncSocketHandler, + OutputBufferingHandler> + pipeline{AsyncSocketHandler(socket), OutputBufferingHandler()}; + EXPECT_TRUE(pipeline.getHandler(0)); + EXPECT_TRUE(pipeline.getHandler(1)); + } + // dynamic + { + ChannelPipeline> pipeline; + pipeline + .addBack(AsyncSocketHandler(socket)) + .addBack(OutputBufferingHandler()) + .finalize(); + EXPECT_TRUE(pipeline.getHandler(0)); + EXPECT_TRUE(pipeline.getHandler(1)); + } +} + +// Test that handlers correctly fire the next handler when directed +TEST(ChannelTest, FireActions) { + IntHandler handler1; + IntHandler handler2; + + EXPECT_CALL(handler1, attachPipeline(_)); + EXPECT_CALL(handler2, attachPipeline(_)); + + ChannelPipeline + pipeline(&handler1, &handler2); + + EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler2, read_(_, _)).Times(1); + pipeline.read(1); + + EXPECT_CALL(handler1, readEOF(_)).WillOnce(FireReadEOF()); + EXPECT_CALL(handler2, readEOF(_)).Times(1); + pipeline.readEOF(); + + EXPECT_CALL(handler1, readException(_, _)).WillOnce(FireReadException()); + EXPECT_CALL(handler2, readException(_, _)).Times(1); + pipeline.readException(make_exception_wrapper("blah")); + + EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler1, write_(_, _)).Times(1); + EXPECT_NO_THROW(pipeline.write(1).value()); + + EXPECT_CALL(handler2, close_(_)).WillOnce(FireClose()); + EXPECT_CALL(handler1, close_(_)).Times(1); + EXPECT_NO_THROW(pipeline.close().value()); + + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); +} + +// Test that nothing bad happens when actions reach the end of the pipeline +// (a warning will be logged, however) +TEST(ChannelTest, ReachEndOfPipeline) { + IntHandler handler; + EXPECT_CALL(handler, attachPipeline(_)); + ChannelPipeline + pipeline(&handler); + + EXPECT_CALL(handler, read_(_, _)).WillOnce(FireRead()); + pipeline.read(1); + + EXPECT_CALL(handler, readEOF(_)).WillOnce(FireReadEOF()); + pipeline.readEOF(); + + EXPECT_CALL(handler, readException(_, _)).WillOnce(FireReadException()); + pipeline.readException(make_exception_wrapper("blah")); + + EXPECT_CALL(handler, write_(_, _)).WillOnce(FireWrite()); + EXPECT_NO_THROW(pipeline.write(1).value()); + + EXPECT_CALL(handler, close_(_)).WillOnce(FireClose()); + EXPECT_NO_THROW(pipeline.close().value()); + + EXPECT_CALL(handler, detachPipeline(_)); +} + +// Test having the last read handler turn around and write +TEST(ChannelTest, TurnAround) { + IntHandler handler1; + IntHandler handler2; + + EXPECT_CALL(handler1, attachPipeline(_)); + EXPECT_CALL(handler2, attachPipeline(_)); + + ChannelPipeline + pipeline(&handler1, &handler2); + + EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler1, write_(_, _)).Times(1); + pipeline.read(1); + + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); +} + +TEST(ChannelTest, DynamicFireActions) { + IntHandler handler1, handler2, handler3; + EXPECT_CALL(handler2, attachPipeline(_)); + ChannelPipeline + pipeline(&handler2); + + EXPECT_CALL(handler1, attachPipeline(_)); + EXPECT_CALL(handler3, attachPipeline(_)); + + pipeline + .addFront(IntHandlerPtr(&handler1)) + .addBack(IntHandlerPtr(&handler3)) + .finalize(); + + EXPECT_TRUE(pipeline.getHandler(0)); + EXPECT_TRUE(pipeline.getHandler(1)); + EXPECT_TRUE(pipeline.getHandler(2)); + + EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler3, read_(_, _)).Times(1); + pipeline.read(1); + + EXPECT_CALL(handler3, write_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler1, write_(_, _)).Times(1); + EXPECT_NO_THROW(pipeline.write(1).value()); + + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); + EXPECT_CALL(handler3, detachPipeline(_)); +} + +template +class ConcreteChannelHandler : public ChannelHandler { + typedef typename ChannelHandler::Context Context; + public: + void read(Context* ctx, Rin msg) {} + Future write(Context* ctx, Win msg) { return makeFuture(); } +}; + +typedef ChannelHandlerAdapter StringHandler; +typedef ConcreteChannelHandler IntToStringHandler; +typedef ConcreteChannelHandler StringToIntHandler; + +TEST(ChannelPipeline, DynamicConstruction) { + { + ChannelPipeline pipeline; + EXPECT_THROW( + pipeline + .addBack(ChannelHandlerAdapter{}) + .finalize(), std::invalid_argument); + } + { + ChannelPipeline pipeline; + EXPECT_THROW( + pipeline + .addFront(ChannelHandlerAdapter{}) + .finalize(), + std::invalid_argument); + } + { + ChannelPipeline + pipeline{StringHandler(), StringHandler()}; + + // Exercise both addFront and addBack. Final pipeline is + // StI <-> ItS <-> StS <-> StS <-> StI <-> ItS + EXPECT_NO_THROW( + pipeline + .addFront(IntToStringHandler{}) + .addFront(StringToIntHandler{}) + .addBack(StringToIntHandler{}) + .addBack(IntToStringHandler{}) + .finalize()); + } +} + +TEST(ChannelPipeline, AttachTransport) { + IntHandler handler; + EXPECT_CALL(handler, attachPipeline(_)); + ChannelPipeline + pipeline(&handler); + + EventBase eb; + auto socket = AsyncSocket::newSocket(&eb); + + EXPECT_CALL(handler, attachTransport(_)); + pipeline.attachTransport(socket); + + EXPECT_CALL(handler, detachTransport(_)); + pipeline.detachTransport(); + + EXPECT_CALL(handler, detachPipeline(_)); +} diff --git a/folly/wangle/channel/test/MockChannelHandler.h b/folly/wangle/channel/test/MockChannelHandler.h new file mode 100644 index 00000000..0c666d94 --- /dev/null +++ b/folly/wangle/channel/test/MockChannelHandler.h @@ -0,0 +1,64 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +template +class MockChannelHandler : public ChannelHandler { + public: + typedef typename ChannelHandler::Context Context; + + MockChannelHandler() = default; + MockChannelHandler(MockChannelHandler&&) = default; + + MOCK_METHOD2_T(read_, void(Context*, Rin&)); + MOCK_METHOD1_T(readEOF, void(Context*)); + MOCK_METHOD2_T(readException, void(Context*, exception_wrapper)); + + MOCK_METHOD2_T(write_, void(Context*, Win&)); + MOCK_METHOD1_T(close_, void(Context*)); + + MOCK_METHOD1_T(attachPipeline, void(Context*)); + MOCK_METHOD1_T(attachTransport, void(Context*)); + MOCK_METHOD1_T(detachPipeline, void(Context*)); + MOCK_METHOD1_T(detachTransport, void(Context*)); + + void read(Context* ctx, Rin msg) { + read_(ctx, msg); + } + + Future write(Context* ctx, Win msg) override { + return makeFutureTry([&](){ + write_(ctx, msg); + }); + } + + Future close(Context* ctx) override { + return makeFutureTry([&](){ + close_(ctx); + }); + } +}; + +template +using MockChannelHandlerAdapter = MockChannelHandler; + +}} diff --git a/folly/wangle/channel/test/OutputBufferingHandlerTest.cpp b/folly/wangle/channel/test/OutputBufferingHandlerTest.cpp new file mode 100644 index 00000000..04a705c5 --- /dev/null +++ b/folly/wangle/channel/test/OutputBufferingHandlerTest.cpp @@ -0,0 +1,59 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace testing; + +typedef StrictMock>> +MockHandler; + +MATCHER_P(IOBufContains, str, "") { return arg->moveToFbString() == str; } + +TEST(OutputBufferingHandlerTest, Basic) { + MockHandler mockHandler; + EXPECT_CALL(mockHandler, attachPipeline(_)); + ChannelPipeline, + ChannelHandlerPtr, + OutputBufferingHandler> + pipeline(&mockHandler, OutputBufferingHandler{}); + + EventBase eb; + auto socket = AsyncSocket::newSocket(&eb); + EXPECT_CALL(mockHandler, attachTransport(_)); + pipeline.attachTransport(socket); + + // Buffering should prevent writes until the EB loops, and the writes should + // be batched into one write call. + auto f1 = pipeline.write(IOBuf::copyBuffer("hello")); + auto f2 = pipeline.write(IOBuf::copyBuffer("world")); + EXPECT_FALSE(f1.isReady()); + EXPECT_FALSE(f2.isReady()); + EXPECT_CALL(mockHandler, write_(_, IOBufContains("helloworld"))); + eb.loopOnce(); + EXPECT_TRUE(f1.isReady()); + EXPECT_TRUE(f2.isReady()); + EXPECT_CALL(mockHandler, detachPipeline(_)); +} diff --git a/folly/wangle/concurrent/BlockingQueue.h b/folly/wangle/concurrent/BlockingQueue.h new file mode 100644 index 00000000..08a1f703 --- /dev/null +++ b/folly/wangle/concurrent/BlockingQueue.h @@ -0,0 +1,42 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +template +class BlockingQueue { + public: + virtual ~BlockingQueue() {} + virtual void add(T item) = 0; + virtual void addWithPriority(T item, uint32_t priority) { + LOG_FIRST_N(WARNING, 1) << + "add(item, priority) called on a non-priority queue"; + add(std::move(item)); + } + virtual uint32_t getNumPriorities() { + LOG_FIRST_N(WARNING, 1) << + "getNumPriorities() called on a non-priority queue"; + return 1; + } + virtual T take() = 0; + virtual size_t size() = 0; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/CPUThreadPoolExecutor.cpp b/folly/wangle/concurrent/CPUThreadPoolExecutor.cpp new file mode 100644 index 00000000..a03c6151 --- /dev/null +++ b/folly/wangle/concurrent/CPUThreadPoolExecutor.cpp @@ -0,0 +1,137 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace folly { namespace wangle { + +const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 18; +const size_t CPUThreadPoolExecutor::kDefaultNumPriorities = 2; + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + std::unique_ptr> taskQueue, + std::shared_ptr threadFactory) + : ThreadPoolExecutor(numThreads, std::move(threadFactory)), + taskQueue_(std::move(taskQueue)) { + addThreads(numThreads); + CHECK(threadList_.get().size() == numThreads); +} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory) + : CPUThreadPoolExecutor( + numThreads, + folly::make_unique>( + CPUThreadPoolExecutor::kDefaultMaxQueueSize), + std::move(threadFactory)) {} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor(size_t numThreads) + : CPUThreadPoolExecutor( + numThreads, + std::make_shared("CPUThreadPool")) {} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + uint32_t numPriorities, + std::shared_ptr threadFactory) + : CPUThreadPoolExecutor( + numThreads, + folly::make_unique>( + numPriorities, + CPUThreadPoolExecutor::kDefaultMaxQueueSize), + std::move(threadFactory)) {} + +CPUThreadPoolExecutor::~CPUThreadPoolExecutor() { + stop(); + CHECK(threadsToStop_ == 0); +} + +void CPUThreadPoolExecutor::add(Func func) { + add(std::move(func), std::chrono::milliseconds(0)); +} + +void CPUThreadPoolExecutor::add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback) { + // TODO handle enqueue failure, here and in other add() callsites + taskQueue_->add( + CPUTask(std::move(func), expiration, std::move(expireCallback))); +} + +void CPUThreadPoolExecutor::add(Func func, uint32_t priority) { + add(std::move(func), priority, std::chrono::milliseconds(0)); +} + +void CPUThreadPoolExecutor::add( + Func func, + uint32_t priority, + std::chrono::milliseconds expiration, + Func expireCallback) { + CHECK(priority < getNumPriorities()); + taskQueue_->addWithPriority( + CPUTask(std::move(func), expiration, std::move(expireCallback)), + priority); +} + +uint32_t CPUThreadPoolExecutor::getNumPriorities() const { + return taskQueue_->getNumPriorities(); +} + +BlockingQueue* +CPUThreadPoolExecutor::getTaskQueue() { + return taskQueue_.get(); +} + +void CPUThreadPoolExecutor::threadRun(std::shared_ptr thread) { + thread->startupBaton.post(); + while (1) { + auto task = taskQueue_->take(); + if (UNLIKELY(task.poison)) { + CHECK(threadsToStop_-- > 0); + stoppedThreads_.add(thread); + return; + } else { + runTask(thread, std::move(task)); + } + + if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) { + if (--threadsToStop_ >= 0) { + stoppedThreads_.add(thread); + return; + } else { + threadsToStop_++; + } + } + } +} + +void CPUThreadPoolExecutor::stopThreads(size_t n) { + CHECK(stoppedThreads_.size() == 0); + threadsToStop_ = n; + for (size_t i = 0; i < n; i++) { + taskQueue_->add(CPUTask()); + } +} + +uint64_t CPUThreadPoolExecutor::getPendingTaskCount() { + return taskQueue_->size(); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/CPUThreadPoolExecutor.h b/folly/wangle/concurrent/CPUThreadPoolExecutor.h new file mode 100644 index 00000000..bc612ae6 --- /dev/null +++ b/folly/wangle/concurrent/CPUThreadPoolExecutor.h @@ -0,0 +1,94 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +class CPUThreadPoolExecutor : public ThreadPoolExecutor { + public: + struct CPUTask; + + explicit CPUThreadPoolExecutor( + size_t numThreads, + std::unique_ptr> taskQueue, + std::shared_ptr threadFactory = + std::make_shared("CPUThreadPool")); + + explicit CPUThreadPoolExecutor(size_t numThreads); + + explicit CPUThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory); + + explicit CPUThreadPoolExecutor( + size_t numThreads, + uint32_t numPriorities, + std::shared_ptr threadFactory = + std::make_shared("CPUThreadPool")); + + ~CPUThreadPoolExecutor(); + + void add(Func func) override; + void add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback = nullptr) override; + + void add(Func func, uint32_t priority); + void add( + Func func, + uint32_t priority, + std::chrono::milliseconds expiration, + Func expireCallback = nullptr); + + uint32_t getNumPriorities() const; + + struct CPUTask : public ThreadPoolExecutor::Task { + // Must be noexcept move constructible so it can be used in MPMCQueue + explicit CPUTask( + Func&& f, + std::chrono::milliseconds expiration, + Func&& expireCallback) + : Task(std::move(f), expiration, std::move(expireCallback)), + poison(false) {} + CPUTask() + : Task(nullptr, std::chrono::milliseconds(0), nullptr), + poison(true) {} + CPUTask(CPUTask&& o) noexcept : Task(std::move(o)), poison(o.poison) {} + CPUTask(const CPUTask&) = default; + CPUTask& operator=(const CPUTask&) = default; + bool poison; + }; + + static const size_t kDefaultMaxQueueSize; + static const size_t kDefaultNumPriorities; + + protected: + BlockingQueue* getTaskQueue(); + + private: + void threadRun(ThreadPtr thread) override; + void stopThreads(size_t n) override; + uint64_t getPendingTaskCount() override; + + std::unique_ptr> taskQueue_; + std::atomic threadsToStop_{0}; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/Codel.cpp b/folly/wangle/concurrent/Codel.cpp new file mode 100644 index 00000000..b2b8cc7d --- /dev/null +++ b/folly/wangle/concurrent/Codel.cpp @@ -0,0 +1,91 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#ifndef NO_LIB_GFLAGS + #include + DEFINE_int32(codel_interval, 100, + "Codel default interval time in ms"); + DEFINE_int32(codel_target_delay, 5, + "Target codel queueing delay in ms"); +#endif + +namespace folly { namespace wangle { + +#ifdef NO_LIB_GFLAGS + int32_t FLAGS_codel_interval = 100; + int32_t FLAGS_codel_target_delay = 5; +#endif + +Codel::Codel() + : codelMinDelay_(0), + codelIntervalTime_(std::chrono::steady_clock::now()), + codelResetDelay_(true), + overloaded_(false) {} + +bool Codel::overloaded(std::chrono::microseconds delay) { + bool ret = false; + auto now = std::chrono::steady_clock::now(); + + // Avoid another thread updating the value at the same time we are using it + // to calculate the overloaded state + auto minDelay = codelMinDelay_; + + if (now > codelIntervalTime_ && + (!codelResetDelay_.load(std::memory_order_acquire) + && !codelResetDelay_.exchange(true))) { + codelIntervalTime_ = now + std::chrono::milliseconds(FLAGS_codel_interval); + + if (minDelay > std::chrono::milliseconds(FLAGS_codel_target_delay)) { + overloaded_ = true; + } else { + overloaded_ = false; + } + } + // Care must be taken that only a single thread resets codelMinDelay_, + // and that it happens after the interval reset above + if (codelResetDelay_.load(std::memory_order_acquire) && + codelResetDelay_.exchange(false)) { + codelMinDelay_ = delay; + // More than one request must come in during an interval before codel + // starts dropping requests + return false; + } else if(delay < codelMinDelay_) { + codelMinDelay_ = delay; + } + + if (overloaded_ && + delay > std::chrono::milliseconds(FLAGS_codel_target_delay * 2)) { + ret = true; + } + + return ret; + +} + +int Codel::getLoad() { + return std::min(100, (int)codelMinDelay_.count() / + (2 * FLAGS_codel_target_delay)); +} + +int Codel::getMinDelay() { + return (int) codelMinDelay_.count(); +} + +}} //namespace diff --git a/folly/wangle/concurrent/Codel.h b/folly/wangle/concurrent/Codel.h new file mode 100644 index 00000000..16f0205b --- /dev/null +++ b/folly/wangle/concurrent/Codel.h @@ -0,0 +1,66 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/* Codel algorithm implementation: + * http://en.wikipedia.org/wiki/CoDel + * + * Algorithm modified slightly: Instead of changing the interval time + * based on the average min delay, instead we use an alternate timeout + * for each task if the min delay during the interval period is too + * high. + * + * This was found to have better latency metrics than changing the + * window size, since we can communicate with the sender via thrift + * instead of only via the tcp window size congestion control, as in TCP. + */ +class Codel { + + public: + Codel(); + + // Given a delay, returns wether the codel algorithm would + // reject a queued request with this delay. + // + // Internally, it also keeps track of the interval + bool overloaded(std::chrono::microseconds delay); + + // Get the queue load, as seen by the codel algorithm + // Gives a rough guess at how bad the queue delay is. + // + // Return: 0 = no delay, 100 = At the queueing limit + int getLoad(); + + int getMinDelay(); + + private: + std::chrono::microseconds codelMinDelay_; + std::chrono::time_point codelIntervalTime_; + + // flag to make overloaded() thread-safe, since we only want + // to reset the delay once per time period + std::atomic codelResetDelay_; + + bool overloaded_; +}; + +}} // Namespace diff --git a/folly/wangle/concurrent/FutureExecutor.h b/folly/wangle/concurrent/FutureExecutor.h new file mode 100644 index 00000000..8aeedff8 --- /dev/null +++ b/folly/wangle/concurrent/FutureExecutor.h @@ -0,0 +1,79 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace folly { namespace wangle { + +template +class FutureExecutor : public ExecutorImpl { + public: + template + explicit FutureExecutor(Args&&... args) + : ExecutorImpl(std::forward(args)...) {} + + /* + * Given a function func that returns a Future, adds that function to the + * contained Executor and returns a Future which will be fulfilled with + * func's result once it has been executed. + * + * For example: auto f = futureExecutor.addFuture([](){ + * return doAsyncWorkAndReturnAFuture(); + * }); + */ + template + typename std::enable_if::type>::value, + typename std::result_of::type>::type + addFuture(F func) { + typedef typename std::result_of::type::value_type T; + Promise promise; + auto future = promise.getFuture(); + auto movePromise = folly::makeMoveWrapper(std::move(promise)); + auto moveFunc = folly::makeMoveWrapper(std::move(func)); + ExecutorImpl::add([movePromise, moveFunc] () mutable { + (*moveFunc)().then([movePromise] (Try&& t) mutable { + movePromise->fulfilTry(std::move(t)); + }); + }); + return future; + } + + /* + * Similar to addFuture above, but takes a func that returns some non-Future + * type T. + * + * For example: auto f = futureExecutor.addFuture([]() { + * return 42; + * }); + */ + template + typename std::enable_if::type>::value, + Future::type>>::type + addFuture(F func) { + typedef typename std::result_of::type T; + Promise promise; + auto future = promise.getFuture(); + auto movePromise = folly::makeMoveWrapper(std::move(promise)); + auto moveFunc = folly::makeMoveWrapper(std::move(func)); + ExecutorImpl::add([movePromise, moveFunc] () mutable { + movePromise->fulfil(std::move(*moveFunc)); + }); + return future; + } +}; + +}} diff --git a/folly/wangle/concurrent/GlobalExecutor.cpp b/folly/wangle/concurrent/GlobalExecutor.cpp new file mode 100644 index 00000000..e8ac292e --- /dev/null +++ b/folly/wangle/concurrent/GlobalExecutor.cpp @@ -0,0 +1,55 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; + +namespace { + +Singleton globalIOThreadPoolSingleton( + "GlobalIOThreadPool", + [](){ + return new IOThreadPoolExecutor( + sysconf(_SC_NPROCESSORS_ONLN), + std::make_shared("GlobalIOThreadPool")); + }); + +} + +namespace folly { namespace wangle { + +IOExecutor* getIOExecutor() { + auto singleton = IOExecutor::getSingleton(); + auto executor = singleton->load(); + while (!executor) { + IOExecutor* nullIOExecutor = nullptr; + singleton->compare_exchange_strong( + nullIOExecutor, + Singleton::get("GlobalIOThreadPool")); + executor = singleton->load(); + } + return executor; +} + +void setIOExecutor(IOExecutor* executor) { + IOExecutor::getSingleton()->store(executor); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/GlobalExecutor.h b/folly/wangle/concurrent/GlobalExecutor.h new file mode 100644 index 00000000..cac76be8 --- /dev/null +++ b/folly/wangle/concurrent/GlobalExecutor.h @@ -0,0 +1,32 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace folly { namespace wangle { + +class IOExecutor; + +// Retrieve the global IOExecutor. If there is none, a default +// IOThreadPoolExecutor will be constructed and returned. +IOExecutor* getIOExecutor(); + +// Set an IOExecutor to be the global IOExecutor which will be returned by +// subsequent calls to getIOExecutor(). IOExecutors will uninstall themselves +// as global when they are destructed. +void setIOExecutor(IOExecutor* executor); + +}} diff --git a/folly/wangle/concurrent/IOExecutor.cpp b/folly/wangle/concurrent/IOExecutor.cpp new file mode 100644 index 00000000..c5c3e796 --- /dev/null +++ b/folly/wangle/concurrent/IOExecutor.cpp @@ -0,0 +1,50 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +using folly::Singleton; +using folly::wangle::IOExecutor; + +namespace { + +Singleton> globalIOExecutorSingleton( + "GlobalIOExecutor", + [](){ + return new std::atomic(nullptr); + }); + +} + +namespace folly { namespace wangle { + +IOExecutor::~IOExecutor() { + auto thisCopy = this; + try { + getSingleton()->compare_exchange_strong(thisCopy, nullptr); + } catch (const std::runtime_error& e) { + // The global IOExecutor singleton was already destructed so doesn't need to + // be restored. Ignore. + } +} + +std::atomic* IOExecutor::getSingleton() { + return Singleton>::get("GlobalIOExecutor"); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/IOExecutor.h b/folly/wangle/concurrent/IOExecutor.h new file mode 100644 index 00000000..14eb6643 --- /dev/null +++ b/folly/wangle/concurrent/IOExecutor.h @@ -0,0 +1,52 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { +class EventBase; +} + +namespace folly { namespace wangle { + +// An IOExecutor is an executor that operates on at least one EventBase. One of +// these EventBases should be accessible via getEventBase(). The event base +// returned by a call to getEventBase() is implementation dependent. +// +// Note that IOExecutors don't necessarily loop on the base themselves - for +// instance, EventBase itself is an IOExecutor but doesn't drive itself. +// +// Implementations of IOExecutor are eligible to become the global IO executor, +// returned on every call to getIOExecutor(), via setIOExecutor(). +// These functions are declared in GlobalExecutor.h +// +// If getIOExecutor is called and none has been set, a default global +// IOThreadPoolExecutor will be created and returned. +class IOExecutor : public virtual Executor { + public: + virtual ~IOExecutor(); + virtual EventBase* getEventBase() = 0; + + private: + static std::atomic* getSingleton(); + friend IOExecutor* getIOExecutor(); + friend void setIOExecutor(IOExecutor* executor); +}; + +}} diff --git a/folly/wangle/concurrent/IOThreadPoolExecutor.cpp b/folly/wangle/concurrent/IOThreadPoolExecutor.cpp new file mode 100644 index 00000000..5c97bf44 --- /dev/null +++ b/folly/wangle/concurrent/IOThreadPoolExecutor.cpp @@ -0,0 +1,180 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include + +#include + +namespace folly { namespace wangle { + +using folly::detail::MemoryIdler; + +/* Class that will free jemalloc caches and madvise the stack away + * if the event loop is unused for some period of time + */ +class MemoryIdlerTimeout + : public AsyncTimeout , public EventBase::LoopCallback { + public: + explicit MemoryIdlerTimeout(EventBase* b) : AsyncTimeout(b), base_(b) {} + + virtual void timeoutExpired() noexcept { + idled = true; + } + + virtual void runLoopCallback() noexcept { + if (idled) { + MemoryIdler::flushLocalMallocCaches(); + MemoryIdler::unmapUnusedStack(MemoryIdler::kDefaultStackToRetain); + + idled = false; + } else { + std::chrono::steady_clock::duration idleTimeout = + MemoryIdler::defaultIdleTimeout.load( + std::memory_order_acquire); + + idleTimeout = MemoryIdler::getVariationTimeout(idleTimeout); + + scheduleTimeout(std::chrono::duration_cast( + idleTimeout).count()); + } + + // reschedule this callback for the next event loop. + base_->runBeforeLoop(this); + } + private: + EventBase* base_; + bool idled{false}; +} ; + +IOThreadPoolExecutor::IOThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory) + : ThreadPoolExecutor(numThreads, std::move(threadFactory)), + nextThread_(0) { + addThreads(numThreads); + CHECK(threadList_.get().size() == numThreads); +} + +IOThreadPoolExecutor::~IOThreadPoolExecutor() { + stop(); +} + +void IOThreadPoolExecutor::add(Func func) { + add(std::move(func), std::chrono::milliseconds(0)); +} + +void IOThreadPoolExecutor::add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback) { + RWSpinLock::ReadHolder{&threadListLock_}; + if (threadList_.get().empty()) { + throw std::runtime_error("No threads available"); + } + auto ioThread = pickThread(); + + auto moveTask = folly::makeMoveWrapper( + Task(std::move(func), expiration, std::move(expireCallback))); + auto wrappedFunc = [ioThread, moveTask] () mutable { + runTask(ioThread, std::move(*moveTask)); + ioThread->pendingTasks--; + }; + + ioThread->pendingTasks++; + if (!ioThread->eventBase->runInEventBaseThread(std::move(wrappedFunc))) { + ioThread->pendingTasks--; + throw std::runtime_error("Unable to run func in event base thread"); + } +} + +std::shared_ptr +IOThreadPoolExecutor::pickThread() { + if (*thisThread_) { + return *thisThread_; + } + auto thread = threadList_.get()[nextThread_++ % threadList_.get().size()]; + return std::static_pointer_cast(thread); +} + +EventBase* IOThreadPoolExecutor::getEventBase() { + return pickThread()->eventBase; +} + +std::shared_ptr +IOThreadPoolExecutor::makeThread() { + return std::make_shared(this); +} + +void IOThreadPoolExecutor::threadRun(ThreadPtr thread) { + const auto ioThread = std::static_pointer_cast(thread); + ioThread->eventBase = + folly::EventBaseManager::get()->getEventBase(); + thisThread_.reset(new std::shared_ptr(ioThread)); + + auto idler = new MemoryIdlerTimeout(ioThread->eventBase); + ioThread->eventBase->runBeforeLoop(idler); + + thread->startupBaton.post(); + while (ioThread->shouldRun) { + ioThread->eventBase->loopForever(); + } + if (isJoin_) { + while (ioThread->pendingTasks > 0) { + ioThread->eventBase->loopOnce(); + } + } + stoppedThreads_.add(ioThread); +} + +// threadListLock_ is writelocked +void IOThreadPoolExecutor::stopThreads(size_t n) { + for (size_t i = 0; i < n; i++) { + const auto ioThread = std::static_pointer_cast( + threadList_.get()[i]); + ioThread->shouldRun = false; + ioThread->eventBase->terminateLoopSoon(); + } +} + +std::vector IOThreadPoolExecutor::getEventBases() { + std::vector bases; + RWSpinLock::ReadHolder{&threadListLock_}; + for (const auto& thread : threadList_.get()) { + auto ioThread = std::static_pointer_cast(thread); + bases.push_back(ioThread->eventBase); + } + return bases; +} + +// threadListLock_ is readlocked +uint64_t IOThreadPoolExecutor::getPendingTaskCount() { + uint64_t count = 0; + for (const auto& thread : threadList_.get()) { + auto ioThread = std::static_pointer_cast(thread); + size_t pendingTasks = ioThread->pendingTasks; + if (pendingTasks > 0 && !ioThread->idle) { + pendingTasks--; + } + count += pendingTasks; + } + return count; +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/IOThreadPoolExecutor.h b/folly/wangle/concurrent/IOThreadPoolExecutor.h new file mode 100644 index 00000000..7c919d1e --- /dev/null +++ b/folly/wangle/concurrent/IOThreadPoolExecutor.h @@ -0,0 +1,67 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace folly { namespace wangle { + +// N.B. For this thread pool, stop() behaves like join() because outstanding +// tasks belong to the event base and will be executed upon its destruction. +class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor { + public: + explicit IOThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory = + std::make_shared("IOThreadPool")); + + ~IOThreadPoolExecutor(); + + void add(Func func) override; + void add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback = nullptr) override; + + EventBase* getEventBase() override; + + std::vector getEventBases(); + + private: + struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread { + IOThread(IOThreadPoolExecutor* pool) + : Thread(pool), + shouldRun(true), + pendingTasks(0) {}; + std::atomic shouldRun; + std::atomic pendingTasks; + EventBase* eventBase; + }; + + ThreadPtr makeThread() override; + std::shared_ptr pickThread(); + void threadRun(ThreadPtr thread) override; + void stopThreads(size_t n) override; + uint64_t getPendingTaskCount() override; + + size_t nextThread_; + ThreadLocal> thisThread_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/LifoSemMPMCQueue.h b/folly/wangle/concurrent/LifoSemMPMCQueue.h new file mode 100644 index 00000000..71c2cafe --- /dev/null +++ b/folly/wangle/concurrent/LifoSemMPMCQueue.h @@ -0,0 +1,57 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace folly { namespace wangle { + +template +class LifoSemMPMCQueue : public BlockingQueue { + public: + explicit LifoSemMPMCQueue(size_t max_capacity) : queue_(max_capacity) {} + + void add(T item) override { + if (!queue_.write(std::move(item))) { + throw std::runtime_error("LifoSemMPMCQueue full, can't add item"); + } + sem_.post(); + } + + T take() override { + T item; + while (!queue_.read(item)) { + sem_.wait(); + } + return item; + } + + size_t capacity() { + return queue_.capacity(); + } + + size_t size() override { + return queue_.size(); + } + + private: + LifoSem sem_; + MPMCQueue queue_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/NamedThreadFactory.h b/folly/wangle/concurrent/NamedThreadFactory.h new file mode 100644 index 00000000..2e2d27de --- /dev/null +++ b/folly/wangle/concurrent/NamedThreadFactory.h @@ -0,0 +1,56 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class NamedThreadFactory : public ThreadFactory { + public: + explicit NamedThreadFactory(folly::StringPiece prefix) + : prefix_(prefix.str()), suffix_(0) {} + + std::thread newThread(Func&& func) override { + auto thread = std::thread(std::move(func)); + folly::setThreadName( + thread.native_handle(), + folly::to(prefix_, suffix_++)); + return thread; + } + + void setNamePrefix(folly::StringPiece prefix) { + prefix_ = prefix.str(); + } + + std::string getNamePrefix() { + return prefix_; + } + + private: + std::string prefix_; + std::atomic suffix_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h b/folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h new file mode 100644 index 00000000..6ee62bfb --- /dev/null +++ b/folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h @@ -0,0 +1,77 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace folly { namespace wangle { + +template +class PriorityLifoSemMPMCQueue : public BlockingQueue { + public: + explicit PriorityLifoSemMPMCQueue(uint32_t numPriorities, size_t capacity) { + CHECK(numPriorities > 0); + queues_.reserve(numPriorities); + for (int i = 0; i < numPriorities; i++) { + queues_.push_back(MPMCQueue(capacity)); + } + } + + uint32_t getNumPriorities() override { + return queues_.size(); + } + + // Add at lowest priority by default + void add(T item) override { + addWithPriority(std::move(item), 0); + } + + void addWithPriority(T item, uint32_t priority) override { + CHECK(priority < queues_.size()); + if (!queues_[priority].write(std::move(item))) { + throw std::runtime_error("LifoSemMPMCQueue full, can't add item"); + } + sem_.post(); + } + + T take() override { + T item; + while (true) { + for (auto it = queues_.rbegin(); it != queues_.rend(); it++) { + if (it->read(item)) { + return item; + } + } + sem_.wait(); + } + } + + size_t size() override { + size_t size = 0; + for (auto& q : queues_) { + size += q.size(); + } + return size; + } + + private: + LifoSem sem_; + std::vector> queues_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/ThreadFactory.h b/folly/wangle/concurrent/ThreadFactory.h new file mode 100644 index 00000000..7654fbc9 --- /dev/null +++ b/folly/wangle/concurrent/ThreadFactory.h @@ -0,0 +1,30 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include + +namespace folly { namespace wangle { + +class ThreadFactory { + public: + virtual ~ThreadFactory() {} + virtual std::thread newThread(Func&& func) = 0; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/ThreadPoolExecutor.cpp b/folly/wangle/concurrent/ThreadPoolExecutor.cpp new file mode 100644 index 00000000..40694754 --- /dev/null +++ b/folly/wangle/concurrent/ThreadPoolExecutor.cpp @@ -0,0 +1,174 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace folly { namespace wangle { + +ThreadPoolExecutor::ThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory) + : threadFactory_(std::move(threadFactory)), + taskStatsSubject_(std::make_shared>()) {} + +ThreadPoolExecutor::~ThreadPoolExecutor() { + CHECK(threadList_.get().size() == 0); +} + +ThreadPoolExecutor::Task::Task( + Func&& func, + std::chrono::milliseconds expiration, + Func&& expireCallback) + : func_(std::move(func)), + expiration_(expiration), + expireCallback_(std::move(expireCallback)) { + // Assume that the task in enqueued on creation + enqueueTime_ = std::chrono::steady_clock::now(); +} + +void ThreadPoolExecutor::runTask( + const ThreadPtr& thread, + Task&& task) { + thread->idle = false; + auto startTime = std::chrono::steady_clock::now(); + task.stats_.waitTime = startTime - task.enqueueTime_; + if (task.expiration_ > std::chrono::milliseconds(0) && + task.stats_.waitTime >= task.expiration_) { + task.stats_.expired = true; + if (task.expireCallback_ != nullptr) { + task.expireCallback_(); + } + } else { + try { + task.func_(); + } catch (const std::exception& e) { + LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " << + typeid(e).name() << " exception: " << e.what(); + } catch (...) { + LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception " + "object"; + } + task.stats_.runTime = std::chrono::steady_clock::now() - startTime; + } + thread->idle = true; + thread->taskStatsSubject->onNext(std::move(task.stats_)); +} + +size_t ThreadPoolExecutor::numThreads() { + RWSpinLock::ReadHolder{&threadListLock_}; + return threadList_.get().size(); +} + +void ThreadPoolExecutor::setNumThreads(size_t n) { + RWSpinLock::WriteHolder{&threadListLock_}; + const auto current = threadList_.get().size(); + if (n > current ) { + addThreads(n - current); + } else if (n < current) { + removeThreads(current - n, true); + } + CHECK(threadList_.get().size() == n); +} + +// threadListLock_ is writelocked +void ThreadPoolExecutor::addThreads(size_t n) { + std::vector newThreads; + for (size_t i = 0; i < n; i++) { + newThreads.push_back(makeThread()); + } + for (auto& thread : newThreads) { + // TODO need a notion of failing to create the thread + // and then handling for that case + thread->handle = threadFactory_->newThread( + std::bind(&ThreadPoolExecutor::threadRun, this, thread)); + threadList_.add(thread); + } + for (auto& thread : newThreads) { + thread->startupBaton.wait(); + } +} + +// threadListLock_ is writelocked +void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) { + CHECK(n <= threadList_.get().size()); + CHECK(stoppedThreads_.size() == 0); + isJoin_ = isJoin; + stopThreads(n); + for (size_t i = 0; i < n; i++) { + auto thread = stoppedThreads_.take(); + thread->handle.join(); + threadList_.remove(thread); + } + CHECK(stoppedThreads_.size() == 0); +} + +void ThreadPoolExecutor::stop() { + RWSpinLock::WriteHolder{&threadListLock_}; + removeThreads(threadList_.get().size(), false); + CHECK(threadList_.get().size() == 0); +} + +void ThreadPoolExecutor::join() { + RWSpinLock::WriteHolder{&threadListLock_}; + removeThreads(threadList_.get().size(), true); + CHECK(threadList_.get().size() == 0); +} + +ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() { + RWSpinLock::ReadHolder{&threadListLock_}; + ThreadPoolExecutor::PoolStats stats; + stats.threadCount = threadList_.get().size(); + for (auto thread : threadList_.get()) { + if (thread->idle) { + stats.idleThreadCount++; + } else { + stats.activeThreadCount++; + } + } + stats.pendingTaskCount = getPendingTaskCount(); + stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount; + return stats; +} + +std::atomic ThreadPoolExecutor::Thread::nextId(0); + +void ThreadPoolExecutor::StoppedThreadQueue::add( + ThreadPoolExecutor::ThreadPtr item) { + std::lock_guard guard(mutex_); + queue_.push(std::move(item)); + sem_.post(); +} + +ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() { + while(1) { + { + std::lock_guard guard(mutex_); + if (queue_.size() > 0) { + auto item = std::move(queue_.front()); + queue_.pop(); + return item; + } + } + sem_.wait(); + } +} + +size_t ThreadPoolExecutor::StoppedThreadQueue::size() { + std::lock_guard guard(mutex_); + return queue_.size(); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/ThreadPoolExecutor.h b/folly/wangle/concurrent/ThreadPoolExecutor.h new file mode 100644 index 00000000..be8f7968 --- /dev/null +++ b/folly/wangle/concurrent/ThreadPoolExecutor.h @@ -0,0 +1,190 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace folly { namespace wangle { + +class ThreadPoolExecutor : public virtual Executor { + public: + explicit ThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory); + + ~ThreadPoolExecutor(); + + virtual void add(Func func) override = 0; + virtual void add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback) = 0; + + void setThreadFactory(std::shared_ptr threadFactory) { + CHECK(numThreads() == 0); + threadFactory_ = std::move(threadFactory); + } + + std::shared_ptr getThreadFactory(void) { + return threadFactory_; + } + + size_t numThreads(); + void setNumThreads(size_t numThreads); + /* + * stop() is best effort - there is no guarantee that unexecuted tasks won't + * be executed before it returns. Specifically, IOThreadPoolExecutor's stop() + * behaves like join(). + */ + void stop(); + void join(); + + struct PoolStats { + PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0), + pendingTaskCount(0), totalTaskCount(0) {} + size_t threadCount, idleThreadCount, activeThreadCount; + uint64_t pendingTaskCount, totalTaskCount; + }; + + PoolStats getPoolStats(); + + struct TaskStats { + TaskStats() : expired(false), waitTime(0), runTime(0) {} + bool expired; + std::chrono::nanoseconds waitTime; + std::chrono::nanoseconds runTime; + }; + + Subscription subscribeToTaskStats( + const ObserverPtr& observer) { + return taskStatsSubject_->subscribe(observer); + } + + protected: + // Prerequisite: threadListLock_ writelocked + void addThreads(size_t n); + // Prerequisite: threadListLock_ writelocked + void removeThreads(size_t n, bool isJoin); + + struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread { + explicit Thread(ThreadPoolExecutor* pool) + : id(nextId++), + handle(), + idle(true), + taskStatsSubject(pool->taskStatsSubject_) {} + + virtual ~Thread() {} + + static std::atomic nextId; + uint64_t id; + std::thread handle; + bool idle; + Baton<> startupBaton; + std::shared_ptr> taskStatsSubject; + }; + + typedef std::shared_ptr ThreadPtr; + + struct Task { + explicit Task( + Func&& func, + std::chrono::milliseconds expiration, + Func&& expireCallback); + Func func_; + TaskStats stats_; + std::chrono::steady_clock::time_point enqueueTime_; + std::chrono::milliseconds expiration_; + Func expireCallback_; + }; + + static void runTask(const ThreadPtr& thread, Task&& task); + + // The function that will be bound to pool threads. It must call + // thread->startupBaton.post() when it's ready to consume work. + virtual void threadRun(ThreadPtr thread) = 0; + + // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue + // Prerequisite: threadListLock_ writelocked + virtual void stopThreads(size_t n) = 0; + + // Create a suitable Thread struct + virtual ThreadPtr makeThread() { + return std::make_shared(this); + } + + // Prerequisite: threadListLock_ readlocked + virtual uint64_t getPendingTaskCount() = 0; + + class ThreadList { + public: + void add(const ThreadPtr& state) { + auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare); + vec_.insert(it, state); + } + + void remove(const ThreadPtr& state) { + auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare); + CHECK(itPair.first != vec_.end()); + CHECK(std::next(itPair.first) == itPair.second); + vec_.erase(itPair.first); + } + + const std::vector& get() const { + return vec_; + } + + private: + static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) { + return ts1->id < ts2->id; + } + + std::vector vec_; + }; + + class StoppedThreadQueue : public BlockingQueue { + public: + void add(ThreadPtr item) override; + ThreadPtr take() override; + size_t size() override; + + private: + LifoSem sem_; + std::mutex mutex_; + std::queue queue_; + }; + + std::shared_ptr threadFactory_; + ThreadList threadList_; + RWSpinLock threadListLock_; + StoppedThreadQueue stoppedThreads_; + std::atomic isJoin_; // whether the current downsizing is a join + + std::shared_ptr> taskStatsSubject_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/test/CodelTest.cpp b/folly/wangle/concurrent/test/CodelTest.cpp new file mode 100644 index 00000000..b61aded3 --- /dev/null +++ b/folly/wangle/concurrent/test/CodelTest.cpp @@ -0,0 +1,38 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +TEST(CodelTest, Basic) { + using std::chrono::milliseconds; + folly::wangle::Codel c; + std::this_thread::sleep_for(milliseconds(110)); + // This interval is overloaded + EXPECT_FALSE(c.overloaded(milliseconds(100))); + std::this_thread::sleep_for(milliseconds(90)); + // At least two requests must happen in an interval before they will fail + EXPECT_FALSE(c.overloaded(milliseconds(50))); + EXPECT_TRUE(c.overloaded(milliseconds(50))); + std::this_thread::sleep_for(milliseconds(110)); + // Previous interval is overloaded, but 2ms isn't enough to fail + EXPECT_FALSE(c.overloaded(milliseconds(2))); + std::this_thread::sleep_for(milliseconds(90)); + // 20 ms > target interval * 2 + EXPECT_TRUE(c.overloaded(milliseconds(20))); +} diff --git a/folly/wangle/concurrent/test/GlobalExecutorTest.cpp b/folly/wangle/concurrent/test/GlobalExecutorTest.cpp new file mode 100644 index 00000000..a601b0c1 --- /dev/null +++ b/folly/wangle/concurrent/test/GlobalExecutorTest.cpp @@ -0,0 +1,51 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +using namespace folly::wangle; + +TEST(GlobalExecutorTest, GlobalIOExecutor) { + class DummyExecutor : public IOExecutor { + public: + void add(folly::Func f) override { + count++; + } + folly::EventBase* getEventBase() override { + return nullptr; + } + int count{0}; + }; + + auto f = [](){}; + + // Don't explode, we should create the default global IOExecutor lazily here. + getIOExecutor()->add(f); + + { + DummyExecutor dummy; + setIOExecutor(&dummy); + getIOExecutor()->add(f); + // Make sure we were properly installed. + EXPECT_EQ(1, dummy.count); + } + + // Don't explode, we should restore the default global IOExecutor when dummy + // is destructed. + getIOExecutor()->add(f); +} diff --git a/folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp b/folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp new file mode 100644 index 00000000..596e2784 --- /dev/null +++ b/folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp @@ -0,0 +1,320 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +using namespace folly::wangle; +using namespace std::chrono; + +static folly::Func burnMs(uint64_t ms) { + return [ms]() { std::this_thread::sleep_for(milliseconds(ms)); }; +} + +template +static void basic() { + // Create and destroy + TPE tpe(10); +} + +TEST(ThreadPoolExecutorTest, CPUBasic) { + basic(); +} + +TEST(IOThreadPoolExecutorTest, IOBasic) { + basic(); +} + +template +static void resize() { + TPE tpe(100); + EXPECT_EQ(100, tpe.numThreads()); + tpe.setNumThreads(50); + EXPECT_EQ(50, tpe.numThreads()); + tpe.setNumThreads(150); + EXPECT_EQ(150, tpe.numThreads()); +} + +TEST(ThreadPoolExecutorTest, CPUResize) { + resize(); +} + +TEST(ThreadPoolExecutorTest, IOResize) { + resize(); +} + +template +static void stop() { + TPE tpe(1); + std::atomic completed(0); + auto f = [&](){ + burnMs(10)(); + completed++; + }; + for (int i = 0; i < 1000; i++) { + tpe.add(f); + } + tpe.stop(); + EXPECT_GT(1000, completed); +} + +// IOThreadPoolExecutor's stop() behaves like join(). Outstanding tasks belong +// to the event base, will be executed upon its destruction, and cannot be +// taken back. +template <> +void stop() { + IOThreadPoolExecutor tpe(1); + std::atomic completed(0); + auto f = [&](){ + burnMs(10)(); + completed++; + }; + for (int i = 0; i < 10; i++) { + tpe.add(f); + } + tpe.stop(); + EXPECT_EQ(10, completed); +} + +TEST(ThreadPoolExecutorTest, CPUStop) { + stop(); +} + +TEST(ThreadPoolExecutorTest, IOStop) { + stop(); +} + +template +static void join() { + TPE tpe(10); + std::atomic completed(0); + auto f = [&](){ + burnMs(1)(); + completed++; + }; + for (int i = 0; i < 1000; i++) { + tpe.add(f); + } + tpe.join(); + EXPECT_EQ(1000, completed); +} + +TEST(ThreadPoolExecutorTest, CPUJoin) { + join(); +} + +TEST(ThreadPoolExecutorTest, IOJoin) { + join(); +} + +template +static void resizeUnderLoad() { + TPE tpe(10); + std::atomic completed(0); + auto f = [&](){ + burnMs(1)(); + completed++; + }; + for (int i = 0; i < 1000; i++) { + tpe.add(f); + } + tpe.setNumThreads(5); + tpe.setNumThreads(15); + tpe.join(); + EXPECT_EQ(1000, completed); +} + +TEST(ThreadPoolExecutorTest, CPUResizeUnderLoad) { + resizeUnderLoad(); +} + +TEST(ThreadPoolExecutorTest, IOResizeUnderLoad) { + resizeUnderLoad(); +} + +template +static void poolStats() { + folly::Baton<> startBaton, endBaton; + TPE tpe(1); + auto stats = tpe.getPoolStats(); + EXPECT_EQ(1, stats.threadCount); + EXPECT_EQ(1, stats.idleThreadCount); + EXPECT_EQ(0, stats.activeThreadCount); + EXPECT_EQ(0, stats.pendingTaskCount); + EXPECT_EQ(0, stats.totalTaskCount); + tpe.add([&](){ startBaton.post(); endBaton.wait(); }); + tpe.add([&](){}); + startBaton.wait(); + stats = tpe.getPoolStats(); + EXPECT_EQ(1, stats.threadCount); + EXPECT_EQ(0, stats.idleThreadCount); + EXPECT_EQ(1, stats.activeThreadCount); + EXPECT_EQ(1, stats.pendingTaskCount); + EXPECT_EQ(2, stats.totalTaskCount); + endBaton.post(); +} + +TEST(ThreadPoolExecutorTest, CPUPoolStats) { + poolStats(); +} + +TEST(ThreadPoolExecutorTest, IOPoolStats) { + poolStats(); +} + +template +static void taskStats() { + TPE tpe(1); + std::atomic c(0); + auto s = tpe.subscribeToTaskStats( + Observer::create( + [&](ThreadPoolExecutor::TaskStats stats) { + int i = c++; + EXPECT_LT(milliseconds(0), stats.runTime); + if (i == 1) { + EXPECT_LT(milliseconds(0), stats.waitTime); + } + })); + tpe.add(burnMs(10)); + tpe.add(burnMs(10)); + tpe.join(); + EXPECT_EQ(2, c); +} + +TEST(ThreadPoolExecutorTest, CPUTaskStats) { + taskStats(); +} + +TEST(ThreadPoolExecutorTest, IOTaskStats) { + taskStats(); +} + +template +static void expiration() { + TPE tpe(1); + std::atomic statCbCount(0); + auto s = tpe.subscribeToTaskStats( + Observer::create( + [&](ThreadPoolExecutor::TaskStats stats) { + int i = statCbCount++; + if (i == 0) { + EXPECT_FALSE(stats.expired); + } else if (i == 1) { + EXPECT_TRUE(stats.expired); + } else { + FAIL(); + } + })); + std::atomic expireCbCount(0); + auto expireCb = [&] () { expireCbCount++; }; + tpe.add(burnMs(10), seconds(60), expireCb); + tpe.add(burnMs(10), milliseconds(10), expireCb); + tpe.join(); + EXPECT_EQ(2, statCbCount); + EXPECT_EQ(1, expireCbCount); +} + +TEST(ThreadPoolExecutorTest, CPUExpiration) { + expiration(); +} + +TEST(ThreadPoolExecutorTest, IOExpiration) { + expiration(); +} + +template +static void futureExecutor() { + FutureExecutor fe(2); + std::atomic c{0}; + fe.addFuture([] () { return makeFuture(42); }).then( + [&] (Try&& t) { + c++; + EXPECT_EQ(42, t.value()); + }); + fe.addFuture([] () { return 100; }).then( + [&] (Try&& t) { + c++; + EXPECT_EQ(100, t.value()); + }); + fe.addFuture([] () { return makeFuture(); }).then( + [&] (Try&& t) { + c++; + EXPECT_NO_THROW(t.value()); + }); + fe.addFuture([] () { return; }).then( + [&] (Try&& t) { + c++; + EXPECT_NO_THROW(t.value()); + }); + fe.addFuture([] () { throw std::runtime_error("oops"); }).then( + [&] (Try&& t) { + c++; + EXPECT_THROW(t.value(), std::runtime_error); + }); + // Test doing actual async work + folly::Baton<> baton; + fe.addFuture([&] () { + auto p = std::make_shared>(); + std::thread t([p](){ + burnMs(10)(); + p->setValue(42); + }); + t.detach(); + return p->getFuture(); + }).then([&] (Try&& t) { + EXPECT_EQ(42, t.value()); + c++; + baton.post(); + }); + baton.wait(); + fe.join(); + EXPECT_EQ(6, c); +} + +TEST(ThreadPoolExecutorTest, CPUFuturePool) { + futureExecutor(); +} + +TEST(ThreadPoolExecutorTest, IOFuturePool) { + futureExecutor(); +} + +TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) { + bool tookLopri = false; + auto completed = 0; + auto hipri = [&] { + EXPECT_FALSE(tookLopri); + completed++; + }; + auto lopri = [&] { + tookLopri = true; + completed++; + }; + CPUThreadPoolExecutor pool(0, 2); + for (int i = 0; i < 50; i++) { + pool.add(lopri, 0); + } + for (int i = 0; i < 50; i++) { + pool.add(hipri, 1); + } + pool.setNumThreads(1); + pool.join(); + EXPECT_EQ(100, completed); +} diff --git a/folly/wangle/rx/Dummy.cpp b/folly/wangle/rx/Dummy.cpp new file mode 100644 index 00000000..02a58d4f --- /dev/null +++ b/folly/wangle/rx/Dummy.cpp @@ -0,0 +1,19 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// fbbuild is too dumb to know that .h files in the directory affect +// our project, unless we have a .cpp file in the target, in the same +// directory. diff --git a/folly/wangle/rx/Observable.h b/folly/wangle/rx/Observable.h new file mode 100644 index 00000000..b9db2f15 --- /dev/null +++ b/folly/wangle/rx/Observable.h @@ -0,0 +1,284 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +template +class Observable { + public: + Observable() : nextSubscriptionId_{1} {} + + // TODO perhaps we want to provide this #5283229 + Observable(Observable&& other) = delete; + + virtual ~Observable() { + if (unsubscriber_) { + unsubscriber_->disable(); + } + } + + // The next three methods subscribe the given Observer to this Observable. + // + // If these are called within an Observer callback, the new observer will not + // get the current update but will get subsequent updates. + // + // subscribe() returns a Subscription object. The observer will continue to + // get updates until the Subscription is destroyed. + // + // observe(ObserverPtr) creates an indefinite subscription + // + // observe(Observer*) also creates an indefinite subscription, but the + // caller is responsible for ensuring that the given Observer outlives this + // Observable. This might be useful in high performance environments where + // allocations must be kept to a minimum. Template parameter InlineObservers + // specifies how many observers can been subscribed inline without any + // allocations (it's just the size of a folly::small_vector). + virtual Subscription subscribe(ObserverPtr observer) { + return subscribeImpl(observer, false); + } + + virtual void observe(ObserverPtr observer) { + subscribeImpl(observer, true); + } + + virtual void observe(Observer* observer) { + if (inCallback_ && *inCallback_) { + if (!newObservers_) { + newObservers_.reset(new ObserverList()); + } + newObservers_->push_back(observer); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + observers_.push_back(observer); + } + } + + // TODO unobserve(ObserverPtr), unobserve(Observer*) + + /// Returns a new Observable that will call back on the given Scheduler. + /// The returned Observable must outlive the parent Observable. + + // This and subscribeOn should maybe just be a first-class feature of an + // Observable, rather than making new ones whose lifetimes are tied to their + // parents. In that case it'd return a reference to this object for + // chaining. + ObservablePtr observeOn(SchedulerPtr scheduler) { + // you're right Hannes, if we have Observable::create we don't need this + // helper class. + struct ViaSubject : public Observable + { + ViaSubject(SchedulerPtr sched, + Observable* obs) + : scheduler_(sched), observable_(obs) + {} + + Subscription subscribe(ObserverPtr o) override { + return observable_->subscribe( + Observer::create( + [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); }, + [=](Error e) { scheduler_->add([o, e] { o->onError(e); }); }, + [=]() { scheduler_->add([o] { o->onCompleted(); }); })); + } + + protected: + SchedulerPtr scheduler_; + Observable* observable_; + }; + + return std::make_shared(scheduler, this); + } + + /// Returns a new Observable that will subscribe to this parent Observable + /// via the given Scheduler. This can be subtle and confusing at first, see + /// http://www.introtorx.com/Content/v1.0.10621.0/15_SchedulingAndThreading.html#SubscribeOnObserveOn + std::unique_ptr subscribeOn(SchedulerPtr scheduler) { + struct Subject_ : public Subject { + public: + Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) { + } + + Subscription subscribe(ObserverPtr o) { + scheduler_->add([=] { + observable_->subscribe(o); + }); + return Subscription(nullptr, 0); // TODO + } + + protected: + SchedulerPtr scheduler_; + Observable* observable_; + }; + + return folly::make_unique(scheduler, this); + } + + protected: + // Safely execute an operation on each observer. F must take a single + // Observer* as its argument. + template + void forEachObserver(F f) { + if (UNLIKELY(!inCallback_)) { + inCallback_.reset(new bool{false}); + } + CHECK(!(*inCallback_)); + *inCallback_ = true; + + { + RWSpinLock::ReadHolder rh(observersLock_); + for (auto o : observers_) { + f(o); + } + + for (auto& kv : subscribers_) { + f(kv.second.get()); + } + } + + if (UNLIKELY((newObservers_ && !newObservers_->empty()) || + (newSubscribers_ && !newSubscribers_->empty()) || + (oldSubscribers_ && !oldSubscribers_->empty()))) { + { + RWSpinLock::WriteHolder wh(observersLock_); + if (newObservers_) { + for (auto observer : *(newObservers_)) { + observers_.push_back(observer); + } + newObservers_->clear(); + } + if (newSubscribers_) { + for (auto& kv : *(newSubscribers_)) { + subscribers_.insert(std::move(kv)); + } + newSubscribers_->clear(); + } + if (oldSubscribers_) { + for (auto id : *(oldSubscribers_)) { + subscribers_.erase(id); + } + oldSubscribers_->clear(); + } + } + } + *inCallback_ = false; + } + + private: + Subscription subscribeImpl(ObserverPtr observer, bool indefinite) { + auto subscription = makeSubscription(indefinite); + typename SubscriberMap::value_type kv{subscription.id_, std::move(observer)}; + if (inCallback_ && *inCallback_) { + if (!newSubscribers_) { + newSubscribers_.reset(new SubscriberMap()); + } + newSubscribers_->insert(std::move(kv)); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + subscribers_.insert(std::move(kv)); + } + return subscription; + } + + class Unsubscriber { + public: + explicit Unsubscriber(Observable* observable) : observable_(observable) { + CHECK(observable_); + } + + void unsubscribe(uint64_t id) { + CHECK(id > 0); + RWSpinLock::ReadHolder guard(lock_); + if (observable_) { + observable_->unsubscribe(id); + } + } + + void disable() { + RWSpinLock::WriteHolder guard(lock_); + observable_ = nullptr; + } + + private: + RWSpinLock lock_; + Observable* observable_; + }; + + std::shared_ptr unsubscriber_{nullptr}; + MicroSpinLock unsubscriberLock_{0}; + + friend class Subscription; + + void unsubscribe(uint64_t id) { + if (inCallback_ && *inCallback_) { + if (!oldSubscribers_) { + oldSubscribers_.reset(new std::vector()); + } + if (newSubscribers_) { + auto it = newSubscribers_->find(id); + if (it != newSubscribers_->end()) { + newSubscribers_->erase(it); + return; + } + } + oldSubscribers_->push_back(id); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + subscribers_.erase(id); + } + } + + Subscription makeSubscription(bool indefinite) { + if (indefinite) { + return Subscription(nullptr, nextSubscriptionId_++); + } else { + if (!unsubscriber_) { + std::lock_guard guard(unsubscriberLock_); + if (!unsubscriber_) { + unsubscriber_ = std::make_shared(this); + } + } + return Subscription(unsubscriber_, nextSubscriptionId_++); + } + } + + std::atomic nextSubscriptionId_; + RWSpinLock observersLock_; + folly::ThreadLocalPtr inCallback_; + + typedef folly::small_vector*, InlineObservers> ObserverList; + ObserverList observers_; + folly::ThreadLocalPtr newObservers_; + + typedef std::map> SubscriberMap; + SubscriberMap subscribers_; + folly::ThreadLocalPtr newSubscribers_; + folly::ThreadLocalPtr> oldSubscribers_; +}; + +}} diff --git a/folly/wangle/rx/Observer.h b/folly/wangle/rx/Observer.h new file mode 100644 index 00000000..b0babce9 --- /dev/null +++ b/folly/wangle/rx/Observer.h @@ -0,0 +1,113 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +template class FunctionObserver; + +/// Observer interface. You can subclass it, or you can just use create() +/// to use std::functions. +template +struct Observer { + // These are what it means to be an Observer. + virtual void onNext(const T&) = 0; + virtual void onError(Error) = 0; + virtual void onCompleted() = 0; + + virtual ~Observer() = default; + + /// Create an Observer with std::function callbacks. Handy to make ad-hoc + /// Observers with lambdas. + /// + /// Templated for maximum perfect forwarding flexibility, but ultimately + /// whatever you pass in has to implicitly become a std::function for the + /// same signature as onNext(), onError(), and onCompleted() respectively. + /// (see the FunctionObserver typedefs) + template + static std::unique_ptr create( + N&& onNextFn, E&& onErrorFn, C&& onCompletedFn) + { + return folly::make_unique>( + std::forward(onNextFn), + std::forward(onErrorFn), + std::forward(onCompletedFn)); + } + + /// Create an Observer with only onNext and onError callbacks. + /// onCompleted will just be a no-op. + template + static std::unique_ptr create(N&& onNextFn, E&& onErrorFn) { + return folly::make_unique>( + std::forward(onNextFn), + std::forward(onErrorFn), + nullptr); + } + + /// Create an Observer with only an onNext callback. + /// onError and onCompleted will just be no-ops. + template + static std::unique_ptr create(N&& onNextFn) { + return folly::make_unique>( + std::forward(onNextFn), + nullptr, + nullptr); + } +}; + +/// An observer that uses std::function callbacks. You don't really want to +/// make one of these directly - instead use the Observer::create() methods. +template +struct FunctionObserver : public Observer { + typedef std::function OnNext; + typedef std::function OnError; + typedef std::function OnCompleted; + + /// We don't need any fancy overloads of this constructor because that's + /// what Observer::create() is for. + template + FunctionObserver(N&& n, E&& e, C&& c) + : onNext_(std::forward(n)), + onError_(std::forward(e)), + onCompleted_(std::forward(c)) + {} + + void onNext(const T& val) override { + if (onNext_) onNext_(val); + } + + void onError(Error e) override { + if (onError_) onError_(e); + } + + void onCompleted() override { + if (onCompleted_) onCompleted_(); + } + + protected: + OnNext onNext_; + OnError onError_; + OnCompleted onCompleted_; +}; + +}} diff --git a/folly/wangle/rx/README b/folly/wangle/rx/README new file mode 100644 index 00000000..ee170f35 --- /dev/null +++ b/folly/wangle/rx/README @@ -0,0 +1,36 @@ +Rx is a pattern for "functional reactive programming" that started at +Microsoft in C#, and has been reimplemented in various languages, notably +RxJava for JVM languages. + +It is basically the plural of Futures (a la Wangle). + + + singular | plural + +---------------------------------+----------------------------------- + sync | Foo getData() | std::vector getData() + async | wangle::Future getData() | wangle::Observable getData() + + +For more on Rx, I recommend these resources: + +Netflix blog post (RxJava): http://techblog.netflix.com/2013/02/rxjava-netflix-api.html +Introduction to Rx eBook (C#): http://www.introtorx.com/content/v1.0.10621.0/01_WhyRx.html +The RxJava wiki: https://github.com/Netflix/RxJava/wiki +Netflix QCon presentation: http://www.infoq.com/presentations/netflix-functional-rx +https://rx.codeplex.com/ + +There are open source C++ implementations, I haven't looked at them. They +might be the best way to go rather than writing it NIH-style. I mostly did it +as an exercise, to think through how closely we might want to integrate +something like this with Wangle, and to get a feel for how it works in C++. + +I haven't even tried to support move-only data in this version. I'm on the +fence about the usage of shared_ptr. Subject is underdeveloped. A whole rich +set of operations is obviously missing. I haven't decided how to handle +subscriptions (and therefore cancellation), but I'm pretty sure C#'s +"Disposable" is thoroughly un-C++ (opposite of RAII). So for now subscribe +returns nothing at all and you can't cancel anything ever. The whole thing is +probably riddled with lifetime corner case bugs that will come out like a +swarm of angry bees as soon as someone tries an infinite sequence, or tries to +partially observe a long sequence. I'm pretty sure subscribeOn has a bug that +I haven't tracked down yet. diff --git a/folly/wangle/rx/Subject.h b/folly/wangle/rx/Subject.h new file mode 100644 index 00000000..8717bdd4 --- /dev/null +++ b/folly/wangle/rx/Subject.h @@ -0,0 +1,46 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/// Subject interface. A Subject is both an Observable and an Observer. There +/// is a default implementation of the Observer methods that just forwards the +/// observed events to the Subject's observers. +template +struct Subject : public Observable, public Observer { + void onNext(const T& val) override { + this->forEachObserver([&](Observer* o){ + o->onNext(val); + }); + } + void onError(Error e) override { + this->forEachObserver([&](Observer* o){ + o->onError(e); + }); + } + void onCompleted() override { + this->forEachObserver([](Observer* o){ + o->onCompleted(); + }); + } +}; + +}} diff --git a/folly/wangle/rx/Subscription.h b/folly/wangle/rx/Subscription.h new file mode 100644 index 00000000..d9885b43 --- /dev/null +++ b/folly/wangle/rx/Subscription.h @@ -0,0 +1,69 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +template +class Subscription { + public: + Subscription() {} + + Subscription(const Subscription&) = delete; + + Subscription(Subscription&& other) noexcept { + *this = std::move(other); + } + + Subscription& operator=(Subscription&& other) noexcept { + unsubscribe(); + unsubscriber_ = std::move(other.unsubscriber_); + id_ = other.id_; + other.unsubscriber_ = nullptr; + other.id_ = 0; + return *this; + } + + ~Subscription() { + unsubscribe(); + } + + private: + typedef typename Observable::Unsubscriber Unsubscriber; + + Subscription(std::shared_ptr unsubscriber, uint64_t id) + : unsubscriber_(std::move(unsubscriber)), id_(id) { + CHECK(id_ > 0); + } + + void unsubscribe() { + if (unsubscriber_ && id_ > 0) { + unsubscriber_->unsubscribe(id_); + id_ = 0; + unsubscriber_ = nullptr; + } + } + + std::shared_ptr unsubscriber_; + uint64_t id_{0}; + + friend class Observable; +}; + +}} diff --git a/folly/wangle/rx/test/RxBenchmark.cpp b/folly/wangle/rx/test/RxBenchmark.cpp new file mode 100644 index 00000000..5e3a7188 --- /dev/null +++ b/folly/wangle/rx/test/RxBenchmark.cpp @@ -0,0 +1,155 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace folly::wangle; +using folly::BenchmarkSuspender; + +static std::unique_ptr> makeObserver() { + return Observer::create([&] (int x) {}); +} + +void subscribeImpl(uint iters, int N, bool countUnsubscribe) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector>> observers; + std::vector> subscriptions; + subscriptions.reserve(N); + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver()); + } + bs.dismiss(); + for (int i = 0; i < N; i++) { + subscriptions.push_back(subject.subscribe(std::move(observers[i]))); + } + if (countUnsubscribe) { + subscriptions.clear(); + } + bs.rehire(); + } +} + +void subscribeAndUnsubscribe(uint iters, int N) { + subscribeImpl(iters, N, true); +} + +void subscribe(uint iters, int N) { + subscribeImpl(iters, N, false); +} + +void observe(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector>> observers; + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver()); + } + bs.dismiss(); + for (int i = 0; i < N; i++) { + subject.observe(std::move(observers[i])); + } + bs.rehire(); + } +} + +void inlineObserve(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector*> observers; + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver().release()); + } + bs.dismiss(); + for (int i = 0; i < N; i++) { + subject.observe(observers[i]); + } + bs.rehire(); + for (int i = 0; i < N; i++) { + delete observers[i]; + } + } +} + +void notifySubscribers(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector>> observers; + std::vector> subscriptions; + subscriptions.reserve(N); + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver()); + } + for (int i = 0; i < N; i++) { + subscriptions.push_back(subject.subscribe(std::move(observers[i]))); + } + bs.dismiss(); + subject.onNext(42); + bs.rehire(); + } +} + +void notifyInlineObservers(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector*> observers; + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver().release()); + } + for (int i = 0; i < N; i++) { + subject.observe(observers[i]); + } + bs.dismiss(); + subject.onNext(42); + bs.rehire(); + } +} + +BENCHMARK_PARAM(subscribeAndUnsubscribe, 1); +BENCHMARK_RELATIVE_PARAM(subscribe, 1); +BENCHMARK_RELATIVE_PARAM(observe, 1); +BENCHMARK_RELATIVE_PARAM(inlineObserve, 1); + +BENCHMARK_DRAW_LINE(); + +BENCHMARK_PARAM(subscribeAndUnsubscribe, 1000); +BENCHMARK_RELATIVE_PARAM(subscribe, 1000); +BENCHMARK_RELATIVE_PARAM(observe, 1000); +BENCHMARK_RELATIVE_PARAM(inlineObserve, 1000); + +BENCHMARK_DRAW_LINE(); + +BENCHMARK_PARAM(notifySubscribers, 1); +BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1); + +BENCHMARK_DRAW_LINE(); + +BENCHMARK_PARAM(notifySubscribers, 1000); +BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1000); + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + folly::runBenchmarks(); + return 0; +} diff --git a/folly/wangle/rx/test/RxTest.cpp b/folly/wangle/rx/test/RxTest.cpp new file mode 100644 index 00000000..d2003fcf --- /dev/null +++ b/folly/wangle/rx/test/RxTest.cpp @@ -0,0 +1,195 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +using namespace folly::wangle; + +static std::unique_ptr> incrementer(int& counter) { + return Observer::create([&] (int x) { + counter++; + }); +} + +TEST(RxTest, Observe) { + Subject subject; + auto count = 0; + subject.observe(incrementer(count)); + subject.onNext(1); + EXPECT_EQ(1, count); +} + +TEST(RxTest, ObserveInline) { + Subject subject; + auto count = 0; + auto o = incrementer(count).release(); + subject.observe(o); + subject.onNext(1); + EXPECT_EQ(1, count); + delete o; +} + +TEST(RxTest, Subscription) { + Subject subject; + auto count = 0; + { + auto s = subject.subscribe(incrementer(count)); + subject.onNext(1); + } + // The subscription has gone out of scope so no one should get this. + subject.onNext(2); + EXPECT_EQ(1, count); +} + +TEST(RxTest, SubscriptionMove) { + Subject subject; + auto count = 0; + auto s = subject.subscribe(incrementer(count)); + auto s2 = subject.subscribe(incrementer(count)); + s2 = std::move(s); + subject.onNext(1); + Subscription s3(std::move(s2)); + subject.onNext(2); + EXPECT_EQ(2, count); +} + +TEST(RxTest, SubscriptionOutlivesSubject) { + Subscription s; + { + Subject subject; + s = subject.subscribe(Observer::create([](int){})); + } + // Don't explode when s is destroyed +} + +TEST(RxTest, SubscribeDuringCallback) { + // A subscriber who was subscribed in the course of a callback should get + // subsequent updates but not the current update. + Subject subject; + int outerCount = 0, innerCount = 0; + Subscription s1, s2; + s1 = subject.subscribe(Observer::create([&] (int x) { + outerCount++; + s2 = subject.subscribe(incrementer(innerCount)); + })); + subject.onNext(42); + subject.onNext(0xDEADBEEF); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(1, innerCount); +} + +TEST(RxTest, ObserveDuringCallback) { + Subject subject; + int outerCount = 0, innerCount = 0; + subject.observe(Observer::create([&] (int x) { + outerCount++; + subject.observe(incrementer(innerCount)); + })); + subject.onNext(42); + subject.onNext(0xDEADBEEF); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(1, innerCount); +} + +TEST(RxTest, ObserveInlineDuringCallback) { + Subject subject; + int outerCount = 0, innerCount = 0; + auto innerO = incrementer(innerCount).release(); + auto outerO = Observer::create([&] (int x) { + outerCount++; + subject.observe(innerO); + }).release(); + subject.observe(outerO); + subject.onNext(42); + subject.onNext(0xDEADBEEF); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(1, innerCount); + delete innerO; + delete outerO; +} + +TEST(RxTest, UnsubscribeDuringCallback) { + // A subscriber who was unsubscribed in the course of a callback should get + // the current update but not subsequent ones + Subject subject; + int count1 = 0, count2 = 0; + auto s1 = subject.subscribe(incrementer(count1)); + auto s2 = subject.subscribe(Observer::create([&] (int x) { + count2++; + s1.~Subscription(); + })); + subject.onNext(1); + subject.onNext(2); + EXPECT_EQ(1, count1); + EXPECT_EQ(2, count2); +} + +TEST(RxTest, SubscribeUnsubscribeDuringCallback) { + // A subscriber who was subscribed and unsubscribed in the course of a + // callback should not get any updates + Subject subject; + int outerCount = 0, innerCount = 0; + auto s2 = subject.subscribe(Observer::create([&] (int x) { + outerCount++; + auto s2 = subject.subscribe(incrementer(innerCount)); + })); + subject.onNext(1); + subject.onNext(2); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(0, innerCount); +} + +// Move only type +typedef std::unique_ptr MO; +static MO makeMO() { return folly::make_unique(1); } +template +static ObserverPtr makeMOObserver() { + return Observer::create([](const T& mo) { + EXPECT_EQ(1, *mo); + }); +} + +TEST(RxTest, MoveOnlyRvalue) { + Subject subject; + auto s1 = subject.subscribe(makeMOObserver()); + auto s2 = subject.subscribe(makeMOObserver()); + auto mo = makeMO(); + // Can't bind lvalues to rvalue references + // subject.onNext(mo); + subject.onNext(std::move(mo)); + subject.onNext(makeMO()); +} + +// Copy only type +struct CO { + CO() = default; + CO(const CO&) = default; + CO(CO&&) = delete; +}; + +template +static ObserverPtr makeCOObserver() { + return Observer::create([](const T& mo) {}); +} + +TEST(RxTest, CopyOnly) { + Subject subject; + auto s1 = subject.subscribe(makeCOObserver()); + CO co; + subject.onNext(co); +} diff --git a/folly/wangle/rx/types.h b/folly/wangle/rx/types.h new file mode 100644 index 00000000..27c2f3b7 --- /dev/null +++ b/folly/wangle/rx/types.h @@ -0,0 +1,35 @@ +/* + * Copyright 2014 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + typedef folly::exception_wrapper Error; + // The Executor is basically an rx Scheduler (by design). So just + // alias it. + typedef std::shared_ptr SchedulerPtr; + + template struct Observable; + template struct Observer; + template struct Subject; + + template using ObservablePtr = std::shared_ptr>; + template using ObserverPtr = std::shared_ptr>; + template using SubjectPtr = std::shared_ptr>; +}} diff --git a/folly/wangle/ssl/ClientHelloExtStats.h b/folly/wangle/ssl/ClientHelloExtStats.h new file mode 100644 index 00000000..a95ee0c6 --- /dev/null +++ b/folly/wangle/ssl/ClientHelloExtStats.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +class ClientHelloExtStats { + public: + virtual ~ClientHelloExtStats() noexcept {} + + // client hello + virtual void recordAbsentHostname() noexcept = 0; + virtual void recordMatch() noexcept = 0; + virtual void recordNotMatch() noexcept = 0; +}; + +} diff --git a/folly/wangle/ssl/DHParam.h b/folly/wangle/ssl/DHParam.h new file mode 100644 index 00000000..561d5691 --- /dev/null +++ b/folly/wangle/ssl/DHParam.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +// The following was auto-generated by +// openssl dhparam -C 2048 +DH *get_dh2048() + { + static unsigned char dh2048_p[]={ + 0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95, + 0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79, + 0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02, + 0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC, + 0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8, + 0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6, + 0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8, + 0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF, + 0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E, + 0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09, + 0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D, + 0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B, + 0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A, + 0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8, + 0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E, + 0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4, + 0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9, + 0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E, + 0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6, + 0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD, + 0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB, + 0x8C,0xA1,0x76,0xDB, + }; + static unsigned char dh2048_g[]={ + 0x02, + }; + DH *dh; + + if ((dh=DH_new()) == nullptr) return(nullptr); + dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr); + dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr); + if ((dh->p == nullptr) || (dh->g == nullptr)) + { DH_free(dh); return(nullptr); } + return(dh); + } diff --git a/folly/wangle/ssl/PasswordInFile.cpp b/folly/wangle/ssl/PasswordInFile.cpp new file mode 100644 index 00000000..77ec6235 --- /dev/null +++ b/folly/wangle/ssl/PasswordInFile.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include + +using namespace std; + +namespace folly { + +PasswordInFile::PasswordInFile(const string& file) + : fileName_(file) { + folly::readFile(file.c_str(), password_); + auto p = password_.find('\0'); + if (p != std::string::npos) { + password_.erase(p); + } +} + +PasswordInFile::~PasswordInFile() { + OPENSSL_cleanse((char *)password_.data(), password_.length()); +} + +} diff --git a/folly/wangle/ssl/PasswordInFile.h b/folly/wangle/ssl/PasswordInFile.h new file mode 100644 index 00000000..b0a09227 --- /dev/null +++ b/folly/wangle/ssl/PasswordInFile.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include // PasswordCollector + +namespace folly { + +class PasswordInFile: public folly::PasswordCollector { + public: + explicit PasswordInFile(const std::string& file); + ~PasswordInFile(); + + void getPassword(std::string& password, int size) override { + password = password_; + } + + const char* getPasswordStr() const { + return password_.c_str(); + } + + std::string describe() const override { + return fileName_; + } + + protected: + std::string fileName_; + std::string password_; +}; + +} diff --git a/folly/wangle/ssl/SSLCacheOptions.h b/folly/wangle/ssl/SSLCacheOptions.h new file mode 100644 index 00000000..56175378 --- /dev/null +++ b/folly/wangle/ssl/SSLCacheOptions.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +namespace folly { + +struct SSLCacheOptions { + std::chrono::seconds sslCacheTimeout; + uint64_t maxSSLCacheSize; + uint64_t sslCacheFlushSize; +}; + +} diff --git a/folly/wangle/ssl/SSLCacheProvider.h b/folly/wangle/ssl/SSLCacheProvider.h new file mode 100644 index 00000000..feecca46 --- /dev/null +++ b/folly/wangle/ssl/SSLCacheProvider.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +class SSLSessionCacheManager; + +/** + * Interface to be implemented by providers of external session caches + */ +class SSLCacheProvider { +public: + /** + * Context saved during an external cache request that is used to + * resume the waiting client. + */ + typedef struct { + std::string sessionId; + SSL_SESSION* session; + SSLSessionCacheManager* manager; + AsyncSSLSocket* sslSocket; + std::unique_ptr< + folly::DelayedDestruction::DestructorGuard> guard; + } CacheContext; + + virtual ~SSLCacheProvider() {} + + /** + * Store a session in the external cache. + * @param sessionId Identifier that can be used later to fetch the + * session with getAsync() + * @param value Serialized session to store + * @param expiration Relative expiration time: seconds from now + * @return true if the storing of the session is initiated successfully + * (though not necessarily completed; the completion may + * happen either before or after this method returns), or + * false if the storing cannot be initiated due to an error. + */ + virtual bool setAsync(const std::string& sessionId, + const std::string& value, + std::chrono::seconds expiration) = 0; + + /** + * Retrieve a session from the external cache. When done, call + * the cache manager's onGetSuccess() or onGetFailure() callback. + * @param sessionId Session ID to fetch + * @param context Data to pass back to the SSLSessionCacheManager + * in the completion callback + * @return true if the lookup of the session is initiated successfully + * (though not necessarily completed; the completion may + * happen either before or after this method returns), or + * false if the lookup cannot be initiated due to an error. + */ + virtual bool getAsync(const std::string& sessionId, + CacheContext* context) = 0; + +}; + +} diff --git a/folly/wangle/ssl/SSLContextConfig.h b/folly/wangle/ssl/SSLContextConfig.h new file mode 100644 index 00000000..bd3f8044 --- /dev/null +++ b/folly/wangle/ssl/SSLContextConfig.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include + +/** + * SSLContextConfig helps to describe the configs/options for + * a SSL_CTX. For example: + * + * 1. Filename of X509, private key and its password. + * 2. ciphers list + * 3. NPN list + * 4. Is session cache enabled? + * 5. Is it the default X509 in SNI operation? + * 6. .... and a few more + */ +namespace folly { + +struct SSLContextConfig { + SSLContextConfig() {} + ~SSLContextConfig() {} + + struct CertificateInfo { + std::string certPath; + std::string keyPath; + std::string passwordPath; + }; + + /** + * Helpers to set/add a certificate + */ + void setCertificate(const std::string& certPath, + const std::string& keyPath, + const std::string& passwordPath) { + certificates.clear(); + addCertificate(certPath, keyPath, passwordPath); + } + + void addCertificate(const std::string& certPath, + const std::string& keyPath, + const std::string& passwordPath) { + certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath}); + } + + /** + * Set the optional list of protocols to advertise via TLS + * Next Protocol Negotiation. An empty list means NPN is not enabled. + */ + void setNextProtocols(const std::list& inNextProtocols) { + nextProtocols.clear(); + nextProtocols.push_back({1, inNextProtocols}); + } + + typedef std::function SNINoMatchFn; + + std::vector certificates; + folly::SSLContext::SSLVersion sslVersion{ + folly::SSLContext::TLSv1}; + bool sessionCacheEnabled{true}; + bool sessionTicketEnabled{true}; + bool clientHelloParsingEnabled{false}; + std::string sslCiphers{ + "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:" + "ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:" + "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:" + "AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:" + "ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:" + "ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"}; + std::string eccCurveName; + // Ciphers to negotiate if TLS version >= 1.1 + std::string tls11Ciphers{""}; + // Weighted lists of NPN strings to advertise + std::list + nextProtocols; + bool isLocalPrivateKey{true}; + // Should this SSLContextConfig be the default for SNI purposes + bool isDefault{false}; + // Callback function to invoke when there are no matching certificates + // (will only be invoked once) + SNINoMatchFn sniNoMatchFn; + // File containing trusted CA's to validate client certificates + std::string clientCAFile; +}; + +} diff --git a/folly/wangle/ssl/SSLContextManager.cpp b/folly/wangle/ssl/SSLContextManager.cpp new file mode 100644 index 00000000..eb90ef3a --- /dev/null +++ b/folly/wangle/ssl/SSLContextManager.cpp @@ -0,0 +1,651 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define OPENSSL_MISSING_FEATURE(name) \ +do { \ + throw std::runtime_error("missing " #name " support in openssl"); \ +} while(0) + + +using std::string; +using std::shared_ptr; + +/** + * SSLContextManager helps to create and manage all SSL_CTX, + * SSLSessionCacheManager and TLSTicketManager for a listening + * VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)). + * + * Other responsibilities: + * 1. It also handles the SSL_CTX selection after getting the tlsext_hostname + * in the client hello message. + * + * Usage: + * 1. Each listening VIP:PORT serving SSL should have one SSLContextManager. + * It maps to Acceptor in the wangle vocabulary. + * + * 2. Create a SSLContextConfig object (e.g. by parsing the JSON config). + * + * 3. Call SSLContextManager::addSSLContextConfig() which will + * then create and configure the SSL_CTX + * + * Note: Each Acceptor, with SSL support, should have one SSLContextManager to + * manage all SSL_CTX for the VIP:PORT. + */ + +namespace folly { + +namespace { + +X509* getX509(SSL_CTX* ctx) { + SSL* ssl = SSL_new(ctx); + SSL_set_connect_state(ssl); + X509* x509 = SSL_get_certificate(ssl); + CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509); + SSL_free(ssl); + return x509; +} + +void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) { +#if OPENSSL_VERSION_NUMBER >= 0x0090800fL +#ifndef OPENSSL_NO_ECDH + EC_KEY* ecdh = nullptr; + int nid; + + /* + * Elliptic-Curve Diffie-Hellman parameters are either "named curves" + * from RFC 4492 section 5.1.1, or explicitly described curves over + * binary fields. OpenSSL only supports the "named curves", which provide + * maximum interoperability. + */ + + nid = OBJ_sn2nid(curveName.c_str()); + if (nid == 0) { + LOG(FATAL) << "Unknown curve name:" << curveName.c_str(); + return; + } + ecdh = EC_KEY_new_by_curve_name(nid); + if (ecdh == nullptr) { + LOG(FATAL) << "Unable to create curve:" << curveName.c_str(); + return; + } + + SSL_CTX_set_tmp_ecdh(ctx, ecdh); + EC_KEY_free(ecdh); +#endif +#endif +} + +// Helper to create TLSTicketKeyManger and aware of the needed openssl +// version/feature. +std::unique_ptr createTicketManagerHelper( + std::shared_ptr ctx, + const TLSTicketKeySeeds* ticketSeeds, + const SSLContextConfig& ctxConfig, + SSLStats* stats) { + + std::unique_ptr ticketManager; +#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB + if (ticketSeeds && ctxConfig.sessionTicketEnabled) { + ticketManager = folly::make_unique(ctx.get(), stats); + ticketManager->setTLSTicketKeySeeds( + ticketSeeds->oldSeeds, + ticketSeeds->currentSeeds, + ticketSeeds->newSeeds); + } else { + ctx->setOptions(SSL_OP_NO_TICKET); + } +#else + if (ticketSeeds && ctxConfig.sessionTicketEnabled) { + OPENSSL_MISSING_FEATURE(TLSTicket); + } +#endif + return ticketManager; +} + +std::string flattenList(const std::list& list) { + std::string s; + bool first = true; + for (auto& item : list) { + if (first) { + first = false; + } else { + s.append(", "); + } + s.append(item); + } + return s; +} + +} + +SSLContextManager::~SSLContextManager() {} + +SSLContextManager::SSLContextManager( + EventBase* eventBase, + const std::string& vipName, + bool strict, + SSLStats* stats) : + stats_(stats), + eventBase_(eventBase), + strict_(strict) { +} + +void SSLContextManager::addSSLContextConfig( + const SSLContextConfig& ctxConfig, + const SSLCacheOptions& cacheOptions, + const TLSTicketKeySeeds* ticketSeeds, + const folly::SocketAddress& vipAddress, + const std::shared_ptr& externalCache) { + + unsigned numCerts = 0; + std::string commonName; + std::string lastCertPath; + std::unique_ptr> subjectAltName; + auto sslCtx = std::make_shared(ctxConfig.sslVersion); + for (const auto& cert : ctxConfig.certificates) { + try { + sslCtx->loadCertificate(cert.certPath.c_str()); + } catch (const std::exception& ex) { + // The exception isn't very useful without the certificate path name, + // so throw a new exception that includes the path to the certificate. + string msg = folly::to("error loading SSL certificate ", + cert.certPath, ": ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + + // Verify that the Common Name and (if present) Subject Alternative Names + // are the same for all the certs specified for the SSL context. + numCerts++; + X509* x509 = getX509(sslCtx->getSSLCtx()); + auto guard = folly::makeGuard([x509] { X509_free(x509); }); + auto cn = SSLUtil::getCommonName(x509); + if (!cn) { + throw std::runtime_error(folly::to("Cannot get CN for X509 ", + cert.certPath)); + } + auto altName = SSLUtil::getSubjectAltName(x509); + VLOG(2) << "cert " << cert.certPath << " CN: " << *cn; + if (altName) { + altName->sort(); + VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName); + } else { + VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}"; + } + if (numCerts == 1) { + commonName = *cn; + subjectAltName = std::move(altName); + } else { + if (commonName != *cn) { + throw std::runtime_error(folly::to("X509 ", cert.certPath, + " does not have same CN as ", + lastCertPath)); + } + if (altName == nullptr) { + if (subjectAltName != nullptr) { + throw std::runtime_error(folly::to("X509 ", cert.certPath, + " does not have same SAN as ", + lastCertPath)); + } + } else { + if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) { + throw std::runtime_error(folly::to("X509 ", cert.certPath, + " does not have same SAN as ", + lastCertPath)); + } + } + } + lastCertPath = cert.certPath; + + // TODO t4438250 - Add ECDSA support to the crypto_ssl offload server + // so we can avoid storing the ECDSA private key in the + // address space of the Internet-facing process. For + // now, if cert name includes "-EC" to denote elliptic + // curve, we load its private key even if the server as + // a whole has been configured for async crypto. + if (ctxConfig.isLocalPrivateKey || + (cert.certPath.find("-EC") != std::string::npos)) { + // The private key lives in the same process + + // This needs to be called before loadPrivateKey(). + if (!cert.passwordPath.empty()) { + auto sslPassword = std::make_shared(cert.passwordPath); + sslCtx->passwordCollector(sslPassword); + } + + try { + sslCtx->loadPrivateKey(cert.keyPath.c_str()); + } catch (const std::exception& ex) { + // Throw an error that includes the key path, so the user can tell + // which key had a problem. + string msg = folly::to("error loading private SSL key ", + cert.keyPath, ": ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + } + } + if (!ctxConfig.isLocalPrivateKey) { + enableAsyncCrypto(sslCtx); + } + + // Let the server pick the highest performing cipher from among the client's + // choices. + // + // Let's use a unique private key for all DH key exchanges. + // + // Because some old implementations choke on empty fragments, most SSL + // applications disable them (it's part of SSL_OP_ALL). This + // will improve performance and decrease write buffer fragmentation. + sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE | + SSL_OP_SINGLE_DH_USE | + SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); + + // Configure SSL ciphers list + if (!ctxConfig.tls11Ciphers.empty()) { + // FIXME: create a dummy SSL_CTX for cipher testing purpose? It can + // remove the ordering dependency + + // Test to see if the specified TLS1.1 ciphers are valid. Note that + // these will be overwritten by the ciphers() call below. + sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers); + } + + // Important that we do this *after* checking the TLS1.1 ciphers above, + // since we test their validity by actually setting them. + sslCtx->ciphers(ctxConfig.sslCiphers); + + // Use a fix DH param + DH* dh = get_dh2048(); + SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh); + DH_free(dh); + + const string& curve = ctxConfig.eccCurveName; + if (!curve.empty()) { + set_key_from_curve(sslCtx->getSSLCtx(), curve); + } + + if (!ctxConfig.clientCAFile.empty()) { + try { + sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT); + sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str()); + sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str()); + } catch (const std::exception& ex) { + string msg = folly::to("error loading client CA", + ctxConfig.clientCAFile, ": ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + } + + // - start - SSL session cache config + // the internal cache never does what we want (per-thread-per-vip). + // Disable it. SSLSessionCacheManager will set it appropriately. + SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF); + SSL_CTX_set_timeout(sslCtx->getSSLCtx(), + cacheOptions.sslCacheTimeout.count()); + std::unique_ptr sessionCacheManager; + if (ctxConfig.sessionCacheEnabled && + cacheOptions.maxSSLCacheSize > 0 && + cacheOptions.sslCacheFlushSize > 0) { + sessionCacheManager = + folly::make_unique( + cacheOptions.maxSSLCacheSize, + cacheOptions.sslCacheFlushSize, + sslCtx.get(), + vipAddress, + commonName, + eventBase_, + stats_, + externalCache); + } + // - end - SSL session cache config + + std::unique_ptr ticketManager = + createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_); + + // finalize sslCtx setup by the individual features supported by openssl + ctxSetupByOpensslFeature(sslCtx, ctxConfig); + + try { + insert(sslCtx, + std::move(sessionCacheManager), + std::move(ticketManager), + ctxConfig.isDefault); + } catch (const std::exception& ex) { + string msg = folly::to("Error adding certificate : ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + +} + +#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK +SSLContext::ServerNameCallbackResult +SSLContextManager::serverNameCallback(SSL* ssl) { + shared_ptr ctx; + + const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (!sn) { + VLOG(6) << "Server Name (tlsext_hostname) is missing"; + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordAbsentHostname(); + } + return SSLContext::SERVER_NAME_NOT_FOUND; + } + size_t snLen = strlen(sn); + VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' "; + + // FIXME: This code breaks the abstraction. Suggestion? + AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl); + CHECK(sslSocket); + + DNString dnstr(sn, snLen); + + uint32_t count = 0; + do { + // Try exact match first + ctx = getSSLCtx(dnstr); + if (ctx) { + sslSocket->switchServerSSLContext(ctx); + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordMatch(); + } + return SSLContext::SERVER_NAME_FOUND; + } + + ctx = getSSLCtxBySuffix(dnstr); + if (ctx) { + sslSocket->switchServerSSLContext(ctx); + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordMatch(); + } + return SSLContext::SERVER_NAME_FOUND; + } + + // Give the noMatchFn one chance to add the correct cert + } + while (count++ == 0 && noMatchFn_ && noMatchFn_(sn)); + + VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn); + + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordNotMatch(); + } + return SSLContext::SERVER_NAME_NOT_FOUND; +} +#endif + +// Consolidate all SSL_CTX setup which depends on openssl version/feature +void +SSLContextManager::ctxSetupByOpensslFeature( + shared_ptr sslCtx, + const SSLContextConfig& ctxConfig) { + // Disable compression - profiling shows this to be very expensive in + // terms of CPU and memory consumption. + // +#ifdef SSL_OP_NO_COMPRESSION + sslCtx->setOptions(SSL_OP_NO_COMPRESSION); +#endif + + // Enable early release of SSL buffers to reduce the memory footprint +#ifdef SSL_MODE_RELEASE_BUFFERS + sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS; +#endif +#ifdef SSL_MODE_EARLY_RELEASE_BBIO + sslCtx->getSSLCtx()->mode |= SSL_MODE_EARLY_RELEASE_BBIO; +#endif + + // This number should (probably) correspond to HTTPSession::kMaxReadSize + // For now, this number must also be large enough to accommodate our + // largest certificate, because some older clients (IE6/7) require the + // cert to be in a single fragment. +#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT + SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000); +#endif + + // Specify cipher(s) to be used for TLS1.1 client + if (!ctxConfig.tls11Ciphers.empty()) { +#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK + // Specified TLS1.1 ciphers are valid + sslCtx->addClientHelloCallback( + std::bind( + &SSLContext::switchCiphersIfTLS11, + sslCtx.get(), + std::placeholders::_1, + ctxConfig.tls11Ciphers + ) + ); +#else + OPENSSL_MISSING_FEATURE(SNI); +#endif + } + + // NPN (Next Protocol Negotiation) + if (!ctxConfig.nextProtocols.empty()) { +#ifdef OPENSSL_NPN_NEGOTIATED + sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols); +#else + OPENSSL_MISSING_FEATURE(NPN); +#endif + } + + // SNI +#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK + noMatchFn_ = ctxConfig.sniNoMatchFn; + if (ctxConfig.isDefault) { + if (defaultCtx_) { + throw std::runtime_error(">1 X509 is set as default"); + } + + defaultCtx_ = sslCtx; + defaultCtx_->setServerNameCallback( + std::bind(&SSLContextManager::serverNameCallback, this, + std::placeholders::_1)); + } +#else + if (ctxs_.size() > 1) { + OPENSSL_MISSING_FEATURE(SNI); + } +#endif +} + +void +SSLContextManager::insert(shared_ptr sslCtx, + std::unique_ptr smanager, + std::unique_ptr tmanager, + bool defaultFallback) { + X509* x509 = getX509(sslCtx->getSSLCtx()); + auto guard = folly::makeGuard([x509] { X509_free(x509); }); + auto cn = SSLUtil::getCommonName(x509); + if (!cn) { + throw std::runtime_error("Cannot get CN"); + } + + /** + * Some notes from RFC 2818. Only for future quick references in case of bugs + * + * RFC 2818 section 3.1: + * "...... + * If a subjectAltName extension of type dNSName is present, that MUST + * be used as the identity. Otherwise, the (most specific) Common Name + * field in the Subject field of the certificate MUST be used. Although + * the use of the Common Name is existing practice, it is deprecated and + * Certification Authorities are encouraged to use the dNSName instead. + * ...... + * In some cases, the URI is specified as an IP address rather than a + * hostname. In this case, the iPAddress subjectAltName must be present + * in the certificate and must exactly match the IP in the URI. + * ......" + */ + + // Not sure if we ever get this kind of X509... + // If we do, assume '*' is always in the CN and ignore all subject alternative + // names. + if (cn->length() == 1 && (*cn)[0] == '*') { + if (!defaultFallback) { + throw std::runtime_error("STAR X509 is not the default"); + } + ctxs_.emplace_back(sslCtx); + sessionCacheManagers_.emplace_back(std::move(smanager)); + ticketManagers_.emplace_back(std::move(tmanager)); + return; + } + + // Insert by CN + insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx); + + // Insert by subject alternative name(s) + auto altNames = SSLUtil::getSubjectAltName(x509); + if (altNames) { + for (auto& name : *altNames) { + insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx); + } + } + + ctxs_.emplace_back(sslCtx); + sessionCacheManagers_.emplace_back(std::move(smanager)); + ticketManagers_.emplace_back(std::move(tmanager)); +} + +void +SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len, + shared_ptr sslCtx) { + try { + insertSSLCtxByDomainNameImpl(dn, len, sslCtx); + } catch (const std::runtime_error& ex) { + if (strict_) { + throw ex; + } else { + LOG(ERROR) << ex.what() << " DN=" << dn; + } + } +} +void +SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len, + shared_ptr sslCtx) +{ + VLOG(4) << + folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for " + "SNI search", dn); + + // Only support wildcard domains which are prefixed exactly by "*." . + // "*" appearing at other locations is not accepted. + + if (len > 2 && dn[0] == '*') { + if (dn[1] == '.') { + // skip the first '*' + dn++; + len--; + } else { + throw std::runtime_error( + "Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" " + "(only allow character \".\" after \"*\""); + } + } + + if (len == 1 && *dn == '.') { + throw std::runtime_error("X509 has only '.' in the CN or subject alternative name " + "(after removing any preceding '*')"); + } + + if (strchr(dn, '*')) { + throw std::runtime_error("X509 has '*' in the the CN or subject alternative name " + "(after removing any preceding '*')"); + } + + DNString dnstr(dn, len); + const auto v = dnMap_.find(dnstr); + if (v == dnMap_.end()) { + dnMap_.emplace(dnstr, sslCtx); + } else if (v->second == sslCtx) { + VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509." + " Ignore the later name."; + } else { + throw std::runtime_error("Duplicate CN or subject alternative name found: \"" + + std::string(dnstr.c_str()) + "\""); + } +} + +shared_ptr +SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const +{ + size_t dot; + + if ((dot = dnstr.find_first_of(".")) != DNString::npos) { + DNString suffixDNStr(dnstr, dot); + const auto v = dnMap_.find(suffixDNStr); + if (v != dnMap_.end()) { + VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"", + dnstr.c_str(), suffixDNStr.c_str()); + return v->second; + } + } + + VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match", + dnstr.c_str()); + return shared_ptr(); +} + +shared_ptr +SSLContextManager::getSSLCtx(const DNString& dnstr) const +{ + const auto v = dnMap_.find(dnstr); + if (v == dnMap_.end()) { + VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match", + dnstr.c_str()); + return shared_ptr(); + } else { + VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str()); + return v->second; + } +} + +shared_ptr +SSLContextManager::getDefaultSSLCtx() const { + return defaultCtx_; +} + +void +SSLContextManager::reloadTLSTicketKeys( + const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds) { +#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB + for (auto& tmgr: ticketManagers_) { + tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds); + } +#endif +} + +} // namespace diff --git a/folly/wangle/ssl/SSLContextManager.h b/folly/wangle/ssl/SSLContextManager.h new file mode 100644 index 00000000..9c55231b --- /dev/null +++ b/folly/wangle/ssl/SSLContextManager.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace folly { + +class SocketAddress; +class SSLContext; +class ClientHelloExtStats; +class SSLCacheOptions; +class SSLStats; +class TLSTicketKeyManager; +class TLSTicketKeySeeds; + +class SSLContextManager { + public: + + explicit SSLContextManager(EventBase* eventBase, + const std::string& vipName, bool strict, + SSLStats* stats); + virtual ~SSLContextManager(); + + /** + * Add a new X509 to SSLContextManager. The details of a X509 + * is passed as a SSLContextConfig object. + * + * @param ctxConfig Details of a X509, its private key, password, etc. + * @param cacheOptions Options for how to do session caching. + * @param ticketSeeds If non-null, the initial ticket key seeds to use. + * @param vipAddress Which VIP are the X509(s) used for? It is only for + * for user friendly log message + * @param externalCache Optional external provider for the session cache; + * may be null + */ + void addSSLContextConfig( + const SSLContextConfig& ctxConfig, + const SSLCacheOptions& cacheOptions, + const TLSTicketKeySeeds* ticketSeeds, + const folly::SocketAddress& vipAddress, + const std::shared_ptr &externalCache); + + /** + * Get the default SSL_CTX for a VIP + */ + std::shared_ptr + getDefaultSSLCtx() const; + + /** + * Search by the _one_ level up subdomain + */ + std::shared_ptr + getSSLCtxBySuffix(const DNString& dnstr) const; + + /** + * Search by the full-string domain name + */ + std::shared_ptr + getSSLCtx(const DNString& dnstr) const; + + /** + * Insert a SSLContext by domain name. + */ + void insertSSLCtxByDomainName( + const char* dn, + size_t len, + std::shared_ptr sslCtx); + + void insertSSLCtxByDomainNameImpl( + const char* dn, + size_t len, + std::shared_ptr sslCtx); + + void reloadTLSTicketKeys(const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds); + + /** + * SSLContextManager only collects SNI stats now + */ + + void setClientHelloExtStats(ClientHelloExtStats* stats) { + clientHelloTLSExtStats_ = stats; + } + + protected: + virtual void enableAsyncCrypto( + const std::shared_ptr& sslCtx) { + LOG(FATAL) << "Unsupported in base SSLContextManager"; + } + SSLStats* stats_{nullptr}; + + private: + SSLContextManager(const SSLContextManager&) = delete; + + void ctxSetupByOpensslFeature( + std::shared_ptr sslCtx, + const SSLContextConfig& ctxConfig); + + /** + * Callback function from openssl to find the right X509 to + * use during SSL handshake + */ +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \ + !defined(OPENSSL_NO_TLSEXT) && \ + defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB) +# define PROXYGEN_HAVE_SERVERNAMECALLBACK + SSLContext::ServerNameCallbackResult + serverNameCallback(SSL* ssl); +#endif + + /** + * The following functions help to maintain the data structure for + * domain name matching in SNI. Some notes: + * + * 1. It is a best match. + * + * 2. It allows wildcard CN and wildcard subject alternative name in a X509. + * The wildcard name must be _prefixed_ by '*.'. It errors out whenever + * it sees '*' in any other locations. + * + * 3. It uses one std::unordered_map object to + * do this. For wildcard name like "*.facebook.com", ".facebook.com" + * is used as the key. + * + * 4. After getting tlsext_hostname from the client hello message, it + * will do a full string search first and then try one level up to + * match any wildcard name (if any) in the X509. + * [Note, browser also only looks one level up when matching the requesting + * domain name with the wildcard name in the server X509]. + */ + + void insert( + std::shared_ptr sslCtx, + std::unique_ptr cmanager, + std::unique_ptr tManager, + bool defaultFallback); + + /** + * Container to own the SSLContext, SSLSessionCacheManager and + * TLSTicketKeyManager. + */ + std::vector> ctxs_; + std::vector> + sessionCacheManagers_; + std::vector> ticketManagers_; + + std::shared_ptr defaultCtx_; + + /** + * Container to store the (DomainName -> SSL_CTX) mapping + */ + std::unordered_map< + DNString, + std::shared_ptr, + DNStringHash> dnMap_; + + EventBase* eventBase_; + ClientHelloExtStats* clientHelloTLSExtStats_{nullptr}; + SSLContextConfig::SNINoMatchFn noMatchFn_; + bool strict_{true}; +}; + +} // namespace diff --git a/folly/wangle/ssl/SSLSessionCacheManager.cpp b/folly/wangle/ssl/SSLSessionCacheManager.cpp new file mode 100644 index 00000000..ec8ccb2d --- /dev/null +++ b/folly/wangle/ssl/SSLSessionCacheManager.cpp @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include +#include + +#include + +#ifndef NO_LIB_GFLAGS +#include +#endif + +using std::string; +using std::shared_ptr; + +namespace { + +const uint32_t NUM_CACHE_BUCKETS = 16; + +// We use the default ID generator which fills the maximum ID length +// for the protocol. 16 bytes for SSLv2 or 32 for SSLv3+ +const int MIN_SESSION_ID_LENGTH = 16; + +} + +#ifndef NO_LIB_GFLAGS +DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache"); +#else +const bool FLAGS_dcache_unit_test = false; +#endif + +namespace folly { + + +int SSLSessionCacheManager::sExDataIndex_ = -1; +shared_ptr SSLSessionCacheManager::sCache_; +std::mutex SSLSessionCacheManager::sCacheLock_; + +LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize, + uint32_t cacheCullSize) + : sessionCache(maxCacheSize, cacheCullSize) { + sessionCache.setPruneHook(std::bind( + &LocalSSLSessionCache::pruneSessionCallback, + this, std::placeholders::_1, + std::placeholders::_2)); +} + +void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId, + SSL_SESSION* session) { + VLOG(4) << "Free SSL session from local cache; id=" + << SSLUtil::hexlify(sessionId); + SSL_SESSION_free(session); + ++removedSessions_; +} + + +// SSLSessionCacheManager implementation + +SSLSessionCacheManager::SSLSessionCacheManager( + uint32_t maxCacheSize, + uint32_t cacheCullSize, + SSLContext* ctx, + const folly::SocketAddress& sockaddr, + const string& context, + EventBase* eventBase, + SSLStats* stats, + const std::shared_ptr& externalCache): + ctx_(ctx), + stats_(stats), + externalCache_(externalCache) { + + SSL_CTX* sslCtx = ctx->getSSLCtx(); + + SSLUtil::getSSLCtxExIndex(&sExDataIndex_); + + SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this); + SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback); + SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback); + SSL_CTX_sess_set_remove_cb(sslCtx, + SSLSessionCacheManager::removeSessionCallback); + if (!FLAGS_dcache_unit_test && !context.empty()) { + // Use the passed in context + SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(), + std::min((int)context.length(), + SSL_MAX_SSL_SESSION_ID_LENGTH)); + } + + SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL + | SSL_SESS_CACHE_SERVER); + + localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize, + cacheCullSize); + + VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context; +} + +SSLSessionCacheManager::~SSLSessionCacheManager() { +} + +void SSLSessionCacheManager::shutdown() { + std::lock_guard g(sCacheLock_); + sCache_.reset(); +} + +shared_ptr SSLSessionCacheManager::getLocalCache( + uint32_t maxCacheSize, + uint32_t cacheCullSize) { + + std::lock_guard g(sCacheLock_); + if (!sCache_) { + sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS, + maxCacheSize, + cacheCullSize)); + } + return sCache_; +} + +int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) { + SSLSessionCacheManager* manager = nullptr; + SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); + manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null SSLSessionCacheManager in callback"; + return -1; + } + return manager->newSession(ssl, session); +} + + +int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) { + string sessionId((char*)session->session_id, session->session_id_length); + VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId); + + if (stats_) { + stats_->recordSSLSession(true /* new session */, false, false); + } + + localCache_->storeSession(sessionId, session, stats_); + + if (externalCache_) { + VLOG(4) << "New SSL session: send session to external cache; id=" << + SSLUtil::hexlify(sessionId); + storeCacheRecord(sessionId, session); + } + + return 1; +} + +void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx, + SSL_SESSION* session) { + SSLSessionCacheManager* manager = nullptr; + manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null SSLSessionCacheManager in callback"; + return; + } + return manager->removeSession(ctx, session); +} + +void SSLSessionCacheManager::removeSession(SSL_CTX* ctx, + SSL_SESSION* session) { + string sessionId((char*)session->session_id, session->session_id_length); + + // This hook is only called from SSL when the internal session cache needs to + // flush sessions. Since we run with the internal cache disabled, this should + // never be called + VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId); + + localCache_->removeSession(sessionId); + + if (stats_) { + stats_->recordSSLSessionRemove(); + } +} + +SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl, + unsigned char* sess_id, + int id_len, + int* copyflag) { + SSLSessionCacheManager* manager = nullptr; + SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); + manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null SSLSessionCacheManager in callback"; + return nullptr; + } + return manager->getSession(ssl, sess_id, id_len, copyflag); +} + +SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl, + unsigned char* session_id, + int id_len, + int* copyflag) { + VLOG(7) << "SSL get session callback"; + SSL_SESSION* session = nullptr; + bool foreign = false; + char const* missReason = nullptr; + + if (id_len < MIN_SESSION_ID_LENGTH) { + // We didn't generate this session so it's going to be a miss. + // This doesn't get logged or counted in the stats. + return nullptr; + } + string sessionId((char*)session_id, id_len); + + AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl); + + assert(sslSocket != nullptr); + + // look it up in the local cache first + session = localCache_->lookupSession(sessionId); +#ifdef SSL_SESSION_CB_WOULD_BLOCK + if (session == nullptr && externalCache_) { + // external cache might have the session + foreign = true; + if (!SSL_want_sess_cache_lookup(ssl)) { + missReason = "reason: No async cache support;"; + } else { + PendingLookupMap::iterator pit = pendingLookups_.find(sessionId); + if (pit == pendingLookups_.end()) { + auto result = pendingLookups_.emplace(sessionId, PendingLookup()); + // initiate fetch + VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" << + sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId); + if (lookupCacheRecord(sessionId, sslSocket)) { + // response is pending + *copyflag = SSL_SESSION_CB_WOULD_BLOCK; + return nullptr; + } else { + missReason = "reason: failed to send lookup request;"; + pendingLookups_.erase(result.first); + } + } else { + // A lookup was already initiated from this thread + if (pit->second.request_in_progress) { + // Someone else initiated the request, attach + VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; " + "fd=" << sslSocket->getFd() << " id=" << + SSLUtil::hexlify(sessionId); + std::unique_ptr dg( + new DelayedDestruction::DestructorGuard(sslSocket)); + pit->second.waiters.push_back( + std::make_pair(sslSocket, std::move(dg))); + *copyflag = SSL_SESSION_CB_WOULD_BLOCK; + return nullptr; + } + // request is complete + session = pit->second.session; // nullptr if our friend didn't have it + if (session != nullptr) { + CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + } + } + } + } +#endif + + bool hit = (session != nullptr); + if (stats_) { + stats_->recordSSLSession(false, hit, foreign); + } + if (hit) { + sslSocket->setSessionIDResumed(true); + } + + VLOG(4) << "Get SSL session [" << + ((hit) ? "Hit" : "Miss") << "]: " << + ((foreign) ? "external" : "local") << " cache; " << + ((missReason != nullptr) ? missReason : "") << "fd=" << + sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId); + + // We already bumped the refcount + *copyflag = 0; + + return session; +} + +bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId, + SSL_SESSION* session) { + std::string sessionString; + uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr); + sessionString.resize(sessionLen); + uint8_t* cp = (uint8_t *)sessionString.data(); + i2d_SSL_SESSION(session, &cp); + size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx()); + return externalCache_->setAsync(sessionId, sessionString, + std::chrono::seconds(expiration)); +} + +bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId, + AsyncSSLSocket* sslSocket) { + auto cacheCtx = new SSLCacheProvider::CacheContext(); + cacheCtx->sessionId = sessionId; + cacheCtx->session = nullptr; + cacheCtx->sslSocket = sslSocket; + cacheCtx->guard.reset( + new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket)); + cacheCtx->manager = this; + bool res = externalCache_->getAsync(sessionId, cacheCtx); + if (!res) { + delete cacheCtx; + } + return res; +} + +void SSLSessionCacheManager::restartSSLAccept( + const SSLCacheProvider::CacheContext* cacheCtx) { + PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId); + CHECK(pit != pendingLookups_.end()); + pit->second.request_in_progress = false; + pit->second.session = cacheCtx->session; + VLOG(7) << "Restart SSL accept"; + cacheCtx->sslSocket->restartSSLAccept(); + for (const auto& attachedLookup: pit->second.waiters) { + // Wake up anyone else who was waiting for this session + VLOG(4) << "Restart SSL accept (waiters) for fd=" << + attachedLookup.first->getFd(); + attachedLookup.first->restartSSLAccept(); + } + pendingLookups_.erase(pit); +} + +void SSLSessionCacheManager::onGetSuccess( + SSLCacheProvider::CacheContext* cacheCtx, + const std::string& value) { + const uint8_t* cp = (uint8_t*)value.data(); + cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length()); + restartSSLAccept(cacheCtx); + + /* Insert in the LRU after restarting all clients. The stats logic + * in getSession would treat this as a local hit otherwise. + */ + localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_); + delete cacheCtx; +} + +void SSLSessionCacheManager::onGetFailure( + SSLCacheProvider::CacheContext* cacheCtx) { + restartSSLAccept(cacheCtx); + delete cacheCtx; +} + +} // namespace diff --git a/folly/wangle/ssl/SSLSessionCacheManager.h b/folly/wangle/ssl/SSLSessionCacheManager.h new file mode 100644 index 00000000..4b1e55c1 --- /dev/null +++ b/folly/wangle/ssl/SSLSessionCacheManager.h @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +#include +#include +#include + +namespace folly { + +class SSLStats; + +/** + * Basic SSL session cache map: Maps session id -> session + */ +typedef folly::EvictingCacheMap SSLSessionCacheMap; + +/** + * Holds an SSLSessionCacheMap and associated lock + */ +class LocalSSLSessionCache: private boost::noncopyable { + public: + LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize); + + ~LocalSSLSessionCache() { + std::lock_guard g(lock); + // EvictingCacheMap dtor doesn't free values + sessionCache.clear(); + } + + SSLSessionCacheMap sessionCache; + std::mutex lock; + uint32_t removedSessions_{0}; + + private: + + void pruneSessionCallback(const std::string& sessionId, + SSL_SESSION* session); +}; + +/** + * A sharded LRU for SSL sessions. The sharding is inteneded to reduce + * contention for the LRU locks. Assuming uniform distribution, two workers + * will contend for the same lock with probability 1 / n_buckets^2. + */ +class ShardedLocalSSLSessionCache : private boost::noncopyable { + public: + ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize, + uint32_t cacheCullSize) { + CHECK(n_buckets > 0); + maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets); + cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets); + if (maxCacheSize == 0) { + maxCacheSize = 1; + } + if (cacheCullSize == 0) { + cacheCullSize = 1; + } + for (uint32_t i = 0; i < n_buckets; i++) { + caches_.push_back( + std::unique_ptr( + new LocalSSLSessionCache(maxCacheSize, cacheCullSize))); + } + } + + SSL_SESSION* lookupSession(const std::string& sessionId) { + size_t bucket = hash(sessionId); + SSL_SESSION* session = nullptr; + std::lock_guard g(caches_[bucket]->lock); + + auto itr = caches_[bucket]->sessionCache.find(sessionId); + if (itr != caches_[bucket]->sessionCache.end()) { + session = itr->second; + } + + if (session) { + CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + } + return session; + } + + void storeSession(const std::string& sessionId, SSL_SESSION* session, + SSLStats* stats) { + size_t bucket = hash(sessionId); + SSL_SESSION* oldSession = nullptr; + std::lock_guard g(caches_[bucket]->lock); + + auto itr = caches_[bucket]->sessionCache.find(sessionId); + if (itr != caches_[bucket]->sessionCache.end()) { + oldSession = itr->second; + } + + if (oldSession) { + // LRUCacheMap doesn't free on overwrite, so 2x the work for us + // This can happen in race conditions + SSL_SESSION_free(oldSession); + } + caches_[bucket]->removedSessions_ = 0; + caches_[bucket]->sessionCache.set(sessionId, session, true); + if (stats) { + stats->recordSSLSessionFree(caches_[bucket]->removedSessions_); + } + } + + void removeSession(const std::string& sessionId) { + size_t bucket = hash(sessionId); + std::lock_guard g(caches_[bucket]->lock); + caches_[bucket]->sessionCache.erase(sessionId); + } + + private: + + /* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */ + size_t hash(const std::string& key) { + CHECK(key.length() >= 2); + return (key[0] << 8 | key[1]) % caches_.size(); + } + + std::vector< std::unique_ptr > caches_; +}; + +/* A socket/DestructorGuard pair */ +typedef std::pair> + AttachedLookup; + +/** + * PendingLookup structure + * + * Keeps track of clients waiting for an SSL session to be retrieved from + * the external cache provider. + */ +struct PendingLookup { + bool request_in_progress; + SSL_SESSION* session; + std::list waiters; + + PendingLookup() { + request_in_progress = true; + session = nullptr; + } +}; + +/* Maps SSL session id to a PendingLookup structure */ +typedef std::map PendingLookupMap; + +/** + * SSLSessionCacheManager handles all stateful session caching. There is an + * instance of this object per SSL VIP per thread, with a 1:1 correlation with + * SSL_CTX. The cache can work locally or in concert with an external cache + * to share sessions across instances. + * + * There is a single in memory session cache shared by all VIPs. The cache is + * split into N buckets (currently 16) with a separate lock per bucket. The + * VIP ID is hashed and stored as part of the session to handle the + * (very unlikely) case of session ID collision. + * + * When a new SSL session is created, it is added to the LRU cache and + * sent to the external cache to be stored. The external cache + * expiration is equal to the SSL session's expiration. + * + * When a resume request is received, SSLSessionCacheManager first looks in the + * local LRU cache for the VIP. If there is a miss there, an asynchronous + * request for this session is dispatched to the external cache. When the + * external cache query returns, the LRU cache is updated if the session was + * found, and the SSL_accept call is resumed. + * + * If additional resume requests for the same session ID arrive in the same + * thread while the request is pending, the 2nd - Nth callers attach to the + * original external cache requests and are resumed when it comes back. No + * attempt is made to coalesce external cache requests for the same session + * ID in different worker threads. Previous work did this, but the + * complexity was deemed to outweigh the potential savings. + * + */ +class SSLSessionCacheManager : private boost::noncopyable { + public: + /** + * Constructor. SSL session related callbacks will be set on the underlying + * SSL_CTX. vipId is assumed to a unique string identifying the VIP and must + * be the same on all servers that wish to share sessions via the same + * external cache. + */ + SSLSessionCacheManager( + uint32_t maxCacheSize, + uint32_t cacheCullSize, + SSLContext* ctx, + const folly::SocketAddress& sockaddr, + const std::string& context, + EventBase* eventBase, + SSLStats* stats, + const std::shared_ptr& externalCache); + + virtual ~SSLSessionCacheManager(); + + /** + * Call this on shutdown to release the global instance of the + * ShardedLocalSSLSessionCache. + */ + static void shutdown(); + + /** + * Callback for ExternalCache to call when an async get succeeds + * @param context The context that was passed to the async get request + * @param value Serialized session + */ + void onGetSuccess(SSLCacheProvider::CacheContext* context, + const std::string& value); + + /** + * Callback for ExternalCache to call when an async get fails, either + * because the requested session is not in the external cache or because + * of an error. + * @param context The context that was passed to the async get request + */ + void onGetFailure(SSLCacheProvider::CacheContext* context); + + private: + + SSLContext* ctx_; + std::shared_ptr localCache_; + PendingLookupMap pendingLookups_; + SSLStats* stats_{nullptr}; + std::shared_ptr externalCache_; + + /** + * Invoked by openssl when a new SSL session is created + */ + int newSession(SSL* ssl, SSL_SESSION* session); + + /** + * Invoked by openssl when an SSL session is ejected from its internal cache. + * This can't be invoked in the current implementation because SSL's internal + * caching is disabled. + */ + void removeSession(SSL_CTX* ctx, SSL_SESSION* session); + + /** + * Invoked by openssl when a client requests a stateful session resumption. + * Triggers a lookup in our local cache and potentially an asynchronous + * request to an external cache. + */ + SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id, + int id_len, int* copyflag); + + /** + * Store a new session record in the external cache + */ + bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session); + + /** + * Lookup a session in the external cache for the specified SSL socket. + */ + bool lookupCacheRecord(const std::string& sessionId, + AsyncSSLSocket* sslSock); + + /** + * Restart all clients waiting for the answer to an external cache query + */ + void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx); + + /** + * Get or create the LRU cache for the given VIP ID + */ + static std::shared_ptr getLocalCache( + uint32_t maxCacheSize, uint32_t cacheCullSize); + + /** + * static functions registered as callbacks to openssl via + * SSL_CTX_sess_set_new/get/remove_cb + */ + static int newSessionCallback(SSL* ssl, SSL_SESSION* session); + static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session); + static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id, + int id_len, int* copyflag); + + static int32_t sExDataIndex_; + static std::shared_ptr sCache_; + static std::mutex sCacheLock_; +}; + +} diff --git a/folly/wangle/ssl/SSLStats.h b/folly/wangle/ssl/SSLStats.h new file mode 100644 index 00000000..761a8434 --- /dev/null +++ b/folly/wangle/ssl/SSLStats.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +class SSLStats { + public: + virtual ~SSLStats() noexcept {} + + // downstream + virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0; + virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0; + virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign) + noexcept = 0; + virtual void recordSSLSessionRemove() noexcept = 0; + virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0; + virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0; + virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0; + virtual void recordClientRenegotiation() noexcept = 0; + + // upstream + virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0; + virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0; + virtual void recordCryptoSSLExternalAttempt() noexcept = 0; + virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0; + virtual void recordCryptoSSLExternalApplicationException() noexcept = 0; + virtual void recordCryptoSSLExternalSuccess() noexcept = 0; + virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0; + virtual void recordCryptoSSLLocalAttempt() noexcept = 0; + virtual void recordCryptoSSLLocalSuccess() noexcept = 0; + +}; + +} diff --git a/folly/wangle/ssl/SSLUtil.cpp b/folly/wangle/ssl/SSLUtil.cpp new file mode 100644 index 00000000..85056856 --- /dev/null +++ b/folly/wangle/ssl/SSLUtil.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL +#define OPENSSL_GE_101 1 +#include +#include +#else +#undef OPENSSL_GE_101 +#endif + +namespace folly { + +std::mutex SSLUtil::sIndexLock_; + +std::unique_ptr SSLUtil::getCommonName(const X509* cert) { + X509_NAME* subject = X509_get_subject_name((X509*)cert); + if (!subject) { + return nullptr; + } + char cn[ub_common_name + 1]; + int res = X509_NAME_get_text_by_NID(subject, NID_commonName, + cn, ub_common_name); + if (res <= 0) { + return nullptr; + } else { + cn[ub_common_name] = '\0'; + return folly::make_unique(cn); + } +} + +std::unique_ptr> SSLUtil::getSubjectAltName( + const X509* cert) { +#ifdef OPENSSL_GE_101 + auto nameList = folly::make_unique>(); + GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i( + (X509*)cert, NID_subject_alt_name, nullptr, nullptr); + if (names) { + auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); }); + size_t count = sk_GENERAL_NAME_num(names); + CHECK(count < std::numeric_limits::max()); + for (int i = 0; i < (int)count; ++i) { + GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i); + if (generalName->type == GEN_DNS) { + ASN1_STRING* s = generalName->d.dNSName; + const char* name = (const char*)ASN1_STRING_data(s); + // I can't find any docs on what a negative return value here + // would mean, so I'm going to ignore it. + auto len = ASN1_STRING_length(s); + DCHECK(len >= 0); + if (size_t(len) != strlen(name)) { + // Null byte(s) in the name; return an error rather than depending on + // the caller to safely handle this case. + return nullptr; + } + nameList->emplace_back(name); + } + } + } + return nameList; +#else + return nullptr; +#endif +} + +} diff --git a/folly/wangle/ssl/SSLUtil.h b/folly/wangle/ssl/SSLUtil.h new file mode 100644 index 00000000..20a17a95 --- /dev/null +++ b/folly/wangle/ssl/SSLUtil.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include + +namespace folly { + +/** + * SSL session establish/resume status + * + * changing these values will break logging pipelines + */ +enum class SSLResumeEnum : uint8_t { + HANDSHAKE = 0, + RESUME_SESSION_ID = 1, + RESUME_TICKET = 3, + NA = 2 +}; + +enum class SSLErrorEnum { + NO_ERROR, + TIMEOUT, + DROPPED +}; + +class SSLUtil { + private: + static std::mutex sIndexLock_; + + public: + /** + * Ensures only one caller will allocate an ex_data index for a given static + * or global. + */ + static void getSSLCtxExIndex(int* pindex) { + std::lock_guard g(sIndexLock_); + if (*pindex < 0) { + *pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + } + + static void getRSAExIndex(int* pindex) { + std::lock_guard g(sIndexLock_); + if (*pindex < 0) { + *pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + } + + static inline std::string hexlify(const std::string& binary) { + std::string hex; + folly::hexlify(binary, hex); + + return hex; + } + + static inline const std::string& hexlify(const std::string& binary, + std::string& hex) { + folly::hexlify(binary, hex); + + return hex; + } + + /** + * Return the SSL resume type for the given socket. + */ + static inline SSLResumeEnum getResumeState( + AsyncSSLSocket* sslSocket) { + return sslSocket->getSSLSessionReused() ? + (sslSocket->sessionIDResumed() ? + SSLResumeEnum::RESUME_SESSION_ID : + SSLResumeEnum::RESUME_TICKET) : + SSLResumeEnum::HANDSHAKE; + } + + /** + * Get the Common Name from an X.509 certificate + * @param cert certificate to inspect + * @return common name, or null if an error occurs + */ + static std::unique_ptr getCommonName(const X509* cert); + + /** + * Get the Subject Alternative Name value(s) from an X.509 certificate + * @param cert certificate to inspect + * @return set of zero or more alternative names, or null if + * an error occurs + */ + static std::unique_ptr> getSubjectAltName( + const X509* cert); +}; + +} diff --git a/folly/wangle/ssl/TLSTicketKeyManager.cpp b/folly/wangle/ssl/TLSTicketKeyManager.cpp new file mode 100644 index 00000000..1f74add6 --- /dev/null +++ b/folly/wangle/ssl/TLSTicketKeyManager.cpp @@ -0,0 +1,305 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +#include +#include +#include +#include +#include + +#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB +using std::string; + +namespace { + +const int kTLSTicketKeyNameLen = 4; +const int kTLSTicketKeySaltLen = 12; + +} + +namespace folly { + + +// TLSTicketKeyManager Implementation +int32_t TLSTicketKeyManager::sExDataIndex_ = -1; + +TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats) + : ctx_(ctx), + randState_(0), + stats_(stats) { + SSLUtil::getSSLCtxExIndex(&sExDataIndex_); + SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this); +} + +TLSTicketKeyManager::~TLSTicketKeyManager() { +} + +int +TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt) { + TLSTicketKeyManager* manager = nullptr; + SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); + manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null TLSTicketKeyManager in callback" ; + return -1; + } + return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt); +} + +int +TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt) { + uint8_t salt[kTLSTicketKeySaltLen]; + uint8_t* saltptr = nullptr; + uint8_t output[SHA256_DIGEST_LENGTH]; + uint8_t* hmacKey = nullptr; + uint8_t* aesKey = nullptr; + TLSTicketKeySource* key = nullptr; + int result = 0; + + if (encrypt) { + key = findEncryptionKey(); + if (key == nullptr) { + // no keys available to encrypt + VLOG(2) << "No TLS ticket key found"; + return -1; + } + VLOG(4) << "Encrypting new ticket with key name=" << + SSLUtil::hexlify(key->keyName_); + + // Get a random salt and write out key name + RAND_pseudo_bytes(salt, (int)sizeof(salt)); + memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen); + memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen); + + // Create the unique keys by hashing with the salt + makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output); + // This relies on the fact that SHA256 has 32 bytes of output + // and that AES-128 keys are 16 bytes + hmacKey = output; + aesKey = output + SHA256_DIGEST_LENGTH / 2; + + // Initialize iv and cipher/mac CTX + RAND_pseudo_bytes(iv, AES_BLOCK_SIZE); + HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2, + EVP_sha256(), nullptr); + EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv); + + result = 1; + } else { + key = findDecryptionKey(keyName); + if (key == nullptr) { + // no ticket found for decryption - will issue a new ticket + if (VLOG_IS_ON(4)) { + string skeyName((char *)keyName, kTLSTicketKeyNameLen); + VLOG(4) << "Can't find ticket key with name=" << + SSLUtil::hexlify(skeyName)<< ", will generate new ticket"; + } + + result = 0; + } else { + VLOG(4) << "Decrypting ticket with key name=" << + SSLUtil::hexlify(key->keyName_); + + // Reconstruct the unique key via the salt + saltptr = keyName + kTLSTicketKeyNameLen; + makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output); + hmacKey = output; + aesKey = output + SHA256_DIGEST_LENGTH / 2; + + // Initialize cipher/mac CTX + HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2, + EVP_sha256(), nullptr); + EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv); + + result = 1; + } + } + // result records whether a ticket key was found to decrypt this ticket, + // not wether the session was re-used. + if (stats_) { + stats_->recordTLSTicket(encrypt, result); + } + + return result; +} + +bool +TLSTicketKeyManager::setTLSTicketKeySeeds( + const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds) { + + bool result = true; + + activeKeys_.clear(); + ticketKeys_.clear(); + ticketSeeds_.clear(); + const std::vector *seedList = &oldSeeds; + for (uint32_t i = 0; i < 3; i++) { + TLSTicketSeedType type = (TLSTicketSeedType)i; + if (type == SEED_CURRENT) { + seedList = ¤tSeeds; + } else if (type == SEED_NEW) { + seedList = &newSeeds; + } + + for (const auto& seedInput: *seedList) { + TLSTicketSeed* seed = insertSeed(seedInput, type); + if (seed == nullptr) { + result = false; + continue; + } + insertNewKey(seed, 1, nullptr); + } + } + if (!result) { + VLOG(2) << "One or more seeds failed to decode"; + } + + if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) { + LOG(WARNING) << "No keys configured, falling back to default"; + SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr); + return false; + } + SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), + TLSTicketKeyManager::callback); + + return true; +} + +string +TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n, + unsigned char* nameBuf) { + SHA256_CTX ctx; + + SHA256_Init(&ctx); + SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_)); + SHA256_Update(&ctx, &n, sizeof(n)); + SHA256_Final(nameBuf, &ctx); + return string((char *)nameBuf, kTLSTicketKeyNameLen); +} + +TLSTicketKeyManager::TLSTicketKeySource* +TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount, + TLSTicketKeySource* prevKey) { + unsigned char nameBuf[SHA256_DIGEST_LENGTH]; + std::unique_ptr newKey(new TLSTicketKeySource); + + // This function supports hash chaining but it is not currently used. + + if (prevKey != nullptr) { + hashNth(prevKey->keySource_, sizeof(prevKey->keySource_), + newKey->keySource_, 1); + } else { + // can't go backwards or the current is missing, start from the beginning + hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(), + newKey->keySource_, hashCount); + } + + newKey->hashCount_ = hashCount; + newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf); + newKey->type_ = seed->type_; + auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_, + std::move(newKey))); + + auto key = it.first->second.get(); + if (key->type_ == SEED_CURRENT) { + activeKeys_.push_back(key); + } + VLOG(4) << "Adding key for " << hashCount << " type=" << + (uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_); + + return key; +} + +void +TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len, + unsigned char* output, uint32_t n) { + assert(n > 0); + for (uint32_t i = 0; i < n; i++) { + SHA256(input, input_len, output); + input = output; + input_len = SHA256_DIGEST_LENGTH; + } +} + +TLSTicketKeyManager::TLSTicketSeed * +TLSTicketKeyManager::insertSeed(const string& seedInput, + TLSTicketSeedType type) { + TLSTicketSeed* seed = nullptr; + string seedOutput; + + if (!folly::unhexlify(seedInput, seedOutput)) { + LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type << + " seed=" << seedInput; + return seed; + } + + seed = new TLSTicketSeed(); + seed->seed_ = seedOutput; + seed->type_ = type; + SHA256((unsigned char *)seedOutput.data(), seedOutput.length(), + seed->seedName_); + ticketSeeds_.push_back(std::unique_ptr(seed)); + + return seed; +} + +TLSTicketKeyManager::TLSTicketKeySource * +TLSTicketKeyManager::findEncryptionKey() { + TLSTicketKeySource* result = nullptr; + // call to rand here is a bit hokey since it's not cryptographically + // random, and is predictably seeded with 0. However, activeKeys_ + // is probably not going to have very many keys in it, and most + // likely only 1. + size_t numKeys = activeKeys_.size(); + if (numKeys > 0) { + result = activeKeys_[rand_r(&randState_) % numKeys]; + } + return result; +} + +TLSTicketKeyManager::TLSTicketKeySource * +TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) { + string name((char *)keyName, kTLSTicketKeyNameLen); + TLSTicketKeySource* key = nullptr; + TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name); + if (mapit != ticketKeys_.end()) { + key = mapit->second.get(); + } + return key; +} + +void +TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey, + size_t keyLen, + unsigned char* salt, + unsigned char* output) { + SHA256_CTX hash_ctx; + + SHA256_Init(&hash_ctx); + SHA256_Update(&hash_ctx, parentKey, keyLen); + SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen); + SHA256_Final(output, &hash_ctx); +} + +} // namespace +#endif diff --git a/folly/wangle/ssl/TLSTicketKeyManager.h b/folly/wangle/ssl/TLSTicketKeyManager.h new file mode 100644 index 00000000..4000c139 --- /dev/null +++ b/folly/wangle/ssl/TLSTicketKeyManager.h @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +namespace folly { + +#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB +class TLSTicketKeyManager {}; +#else +class SSLStats; +/** + * The TLSTicketKeyManager handles TLS ticket key encryption and decryption in + * a way that facilitates sharing the ticket keys across a range of servers. + * Hash chaining is employed to achieve frequent key rotation with minimal + * configuration change. The scheme is as follows: + * + * The manager is supplied with three lists of seeds (old, current and new). + * The config should be updated with new seeds periodically (e.g., daily). + * 3 config changes are recommended to achieve the smoothest seed rotation + * eg: + * 1. Introduce new seed in the push prior to rotation + * 2. Rotation push + * 3. Remove old seeds in the push following rotation + * + * Multiple seeds are supported but only a single seed is required. + * + * Generating encryption keys from the seed works as follows. For a given + * seed, hash forward N times where N is currently the constant 1. + * This is the base key. The name of the base key is the first 4 + * bytes of hash(hash(seed), N). This is copied into the first 4 bytes of the + * TLS ticket key name field. + * + * For each new ticket encryption, the manager generates a random 12 byte salt. + * Hash the salt and the base key together to form the encryption key for + * that ticket. The salt is included in the ticket's 'key name' field so it + * can be used to derive the decryption key. The salt is copied into the second + * 8 bytes of the TLS ticket key name field. + * + * A key is valid for decryption for the lifetime of the instance. + * Sessions will be valid for less time than that, which results in an extra + * symmetric decryption to discover the session is expired. + * + * A TLSTicketKeyManager should be used in only one thread, and should have + * a 1:1 relationship with the SSLContext provided. + * + */ +class TLSTicketKeyManager : private boost::noncopyable { + public: + + explicit TLSTicketKeyManager(folly::SSLContext* ctx, + SSLStats* stats); + + virtual ~TLSTicketKeyManager(); + + /** + * SSL callback to set up encryption/decryption context for a TLS Ticket Key. + * + * This will be supplied to the SSL library via + * SSL_CTX_set_tlsext_ticket_key_cb. + */ + static int callback(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt); + + /** + * Initialize the manager with three sets of seeds. There must be at least + * one current seed, or the manager will revert to the default SSL behavior. + * + * @param oldSeeds Seeds previously used which can still decrypt. + * @param currentSeeds Seeds to use for new ticket encryptions. + * @param newSeeds Seeds which will be used soon, can be used to decrypt + * in case some servers in the cluster have already rotated. + */ + bool setTLSTicketKeySeeds(const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds); + + private: + enum TLSTicketSeedType { + SEED_OLD = 0, + SEED_CURRENT, + SEED_NEW + }; + + /* The seeds supplied by the configuration */ + struct TLSTicketSeed { + std::string seed_; + TLSTicketSeedType type_; + unsigned char seedName_[SHA256_DIGEST_LENGTH]; + }; + + struct TLSTicketKeySource { + int32_t hashCount_; + std::string keyName_; + TLSTicketSeedType type_; + unsigned char keySource_[SHA256_DIGEST_LENGTH]; + }; + + /** + * Method to setup encryption/decryption context for a TLS Ticket Key + * + * OpenSSL documentation is thin on the return value semantics. + * + * For encrypt=1, return < 0 on error, >= 0 for successfully initialized + * For encrypt=0, return < 0 on error, 0 on key not found + * 1 on key found, 2 renew_ticket + * + * renew_ticket means a new ticket will be issued. We could return this value + * when receiving a ticket encrypted with a key derived from an OLD seed. + * However, session_timeout seconds after deploying with a seed + * rotated from CURRENT -> OLD, there will be no valid tickets outstanding + * encrypted with the old key. This grace period means no unnecessary + * handshakes will be performed. If the seed is believed compromised, it + * should NOT be configured as an OLD seed. + */ + int processTicket(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt); + + // Creates the name for the nth key generated from seed + std::string makeKeyName(TLSTicketSeed* seed, uint32_t n, + unsigned char* nameBuf); + + /** + * Creates the key hashCount hashes from the given seed and inserts it in + * ticketKeys. A naked pointer to the key is returned for additional + * processing if needed. + */ + TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount, + TLSTicketKeySource* prevKeySource); + + /** + * hashes input N times placing result in output, which must be at least + * SHA256_DIGEST_LENGTH long. + */ + void hashNth(const unsigned char* input, size_t input_len, + unsigned char* output, uint32_t n); + + /** + * Adds the given seed to the manager + */ + TLSTicketSeed* insertSeed(const std::string& seedInput, + TLSTicketSeedType type); + + /** + * Locate a key for encrypting a new ticket + */ + TLSTicketKeySource* findEncryptionKey(); + + /** + * Locate a key for decrypting a ticket with the given keyName + */ + TLSTicketKeySource* findDecryptionKey(unsigned char* keyName); + + /** + * Derive a unique key from the parent key and the salt via hashing + */ + void makeUniqueKeys(unsigned char* parentKey, size_t keyLen, + unsigned char* salt, unsigned char* output); + + /** + * For standalone decryption utility + */ + friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager, + const std::string& testTicket, + SSL_SESSION **psess); + + typedef std::vector> TLSTicketSeedList; + typedef std::map > + TLSTicketKeyMap; + typedef std::vector TLSActiveKeyList; + + TLSTicketSeedList ticketSeeds_; + // All key sources that can be used for decryption + TLSTicketKeyMap ticketKeys_; + // Key sources that can be used for encryption + TLSActiveKeyList activeKeys_; + + folly::SSLContext* ctx_; + uint32_t randState_; + SSLStats* stats_{nullptr}; + + static int32_t sExDataIndex_; +}; +#endif +} diff --git a/folly/wangle/ssl/TLSTicketKeySeeds.h b/folly/wangle/ssl/TLSTicketKeySeeds.h new file mode 100644 index 00000000..c40ae581 --- /dev/null +++ b/folly/wangle/ssl/TLSTicketKeySeeds.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +struct TLSTicketKeySeeds { + std::vector oldSeeds; + std::vector currentSeeds; + std::vector newSeeds; +}; + +} diff --git a/folly/wangle/ssl/test/SSLCacheTest.cpp b/folly/wangle/ssl/test/SSLCacheTest.cpp new file mode 100644 index 00000000..2433cfc0 --- /dev/null +++ b/folly/wangle/ssl/test/SSLCacheTest.cpp @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace folly; + +DEFINE_int32(clients, 1, "Number of simulated SSL clients"); +DEFINE_int32(threads, 1, "Number of threads to spread clients across"); +DEFINE_int32(requests, 2, "Total number of requests per client"); +DEFINE_int32(port, 9423, "Server port"); +DEFINE_bool(sticky, false, "A given client sends all reqs to one " + "(random) server"); +DEFINE_bool(global, false, "All clients in a thread use the same SSL session"); +DEFINE_bool(handshakes, false, "Force 100% handshakes"); + +string f_servers[10]; +int f_num_servers = 0; +int tnum = 0; + +class ClientRunner { + public: + + ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {} + void run(); + + int reqs; + int hits; + int miss; + int num; +}; + +class SSLCacheClient : public AsyncSocket::ConnectCallback, + public AsyncSSLSocket::HandshakeCB +{ +private: + EventBase* eventBase_; + int currReq_; + int serverIdx_; + AsyncSocket* socket_; + AsyncSSLSocket* sslSocket_; + SSL_SESSION* session_; + SSL_SESSION **pSess_; + std::shared_ptr ctx_; + ClientRunner* cr_; + +public: + SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr); + ~SSLCacheClient() { + if (session_ && !FLAGS_global) + SSL_SESSION_free(session_); + if (socket_ != nullptr) { + if (sslSocket_ != nullptr) { + sslSocket_->destroy(); + sslSocket_ = nullptr; + } + socket_->destroy(); + socket_ = nullptr; + } + }; + + void start(); + + virtual void connectSuccess() noexcept; + + virtual void connectErr(const AsyncSocketException& ex) + noexcept ; + + virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept; + + virtual void handshakeErr( + AsyncSSLSocket* sock, + const AsyncSocketException& ex) noexcept; + +}; + +int +main(int argc, char* argv[]) +{ + gflags::SetUsageMessage(std::string("\n\n" +"usage: sslcachetest [options] -c -t servers\n" +)); + gflags::ParseCommandLineFlags(&argc, &argv, true); + int reqs = 0; + int hits = 0; + int miss = 0; + struct timeval start; + struct timeval end; + struct timeval result; + + srand((unsigned int)time(nullptr)); + + for (int i = 1; i < argc; i++) { + f_servers[f_num_servers++] = argv[i]; + } + if (f_num_servers == 0) { + cout << "require at least one server\n"; + return 1; + } + + gettimeofday(&start, nullptr); + if (FLAGS_threads == 1) { + ClientRunner r; + r.run(); + gettimeofday(&end, nullptr); + reqs = r.reqs; + hits = r.hits; + miss = r.miss; + } + else { + std::vector clients; + std::vector threads; + for (int t = 0; t < FLAGS_threads; t++) { + threads.emplace_back([&] { + clients[t].run(); + }); + } + for (auto& thr: threads) { + thr.join(); + } + gettimeofday(&end, nullptr); + + for (const auto& client: clients) { + reqs += client.reqs; + hits += client.hits; + miss += client.miss; + } + } + + timersub(&end, &start, &result); + + cout << "Requests: " << reqs << endl; + cout << "Handshakes: " << miss << endl; + cout << "Resumes: " << hits << endl; + cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 << + endl; + + cout << "ops/sec: " << (reqs * 1.0) / + ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl; + + return 0; +} + +void +ClientRunner::run() +{ + EventBase eb; + std::list clients; + SSL_SESSION* session = nullptr; + + for (int i = 0; i < FLAGS_clients; i++) { + SSLCacheClient* c = new SSLCacheClient(&eb, &session, this); + c->start(); + clients.push_back(c); + } + + eb.loop(); + + for (auto it = clients.begin(); it != clients.end(); it++) { + delete* it; + } + + reqs += hits + miss; +} + +SSLCacheClient::SSLCacheClient(EventBase* eb, + SSL_SESSION **pSess, + ClientRunner* cr) + : eventBase_(eb), + currReq_(0), + serverIdx_(0), + socket_(nullptr), + sslSocket_(nullptr), + session_(nullptr), + pSess_(pSess), + cr_(cr) +{ + ctx_.reset(new SSLContext()); + ctx_->setOptions(SSL_OP_NO_TICKET); +} + +void +SSLCacheClient::start() +{ + if (currReq_ >= FLAGS_requests) { + cout << "+"; + return; + } + + if (currReq_ == 0 || !FLAGS_sticky) { + serverIdx_ = rand() % f_num_servers; + } + if (socket_ != nullptr) { + if (sslSocket_ != nullptr) { + sslSocket_->destroy(); + sslSocket_ = nullptr; + } + socket_->destroy(); + socket_ = nullptr; + } + socket_ = new AsyncSocket(eventBase_); + socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port); +} + +void +SSLCacheClient::connectSuccess() noexcept +{ + sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(), + false); + + if (!FLAGS_handshakes) { + if (session_ != nullptr) + sslSocket_->setSSLSession(session_); + else if (FLAGS_global && pSess_ != nullptr) + sslSocket_->setSSLSession(*pSess_); + } + sslSocket_->sslConn(this); +} + +void +SSLCacheClient::connectErr(const AsyncSocketException& ex) + noexcept +{ + cout << "connectError: " << ex.what() << endl; +} + +void +SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept +{ + if (sslSocket_->getSSLSessionReused()) { + cr_->hits++; + } else { + cr_->miss++; + if (session_ != nullptr) { + SSL_SESSION_free(session_); + } + session_ = sslSocket_->getSSLSession(); + if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) { + *pSess_ = session_; + } + } + if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) { + cout << "."; + cout.flush(); + } + sslSocket_->closeNow(); + currReq_++; + this->start(); +} + +void +SSLCacheClient::handshakeErr( + AsyncSSLSocket* sock, + const AsyncSocketException& ex) + noexcept +{ + cout << "handshakeError: " << ex.what() << endl; +} diff --git a/folly/wangle/ssl/test/SSLContextManagerTest.cpp b/folly/wangle/ssl/test/SSLContextManagerTest.cpp new file mode 100644 index 00000000..3dd9b29e --- /dev/null +++ b/folly/wangle/ssl/test/SSLContextManagerTest.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2014, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include +#include +#include +#include +#include +#include + +using std::shared_ptr; + +namespace folly { + +TEST(SSLContextManagerTest, Test1) +{ + EventBase eventBase; + SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_", + true, nullptr); + auto www_facebook_com_ctx = std::make_shared(); + auto start_facebook_com_ctx = std::make_shared(); + auto start_abc_facebook_com_ctx = std::make_shared(); + + sslCtxMgr.insertSSLCtxByDomainName( + "www.facebook.com", + strlen("www.facebook.com"), + www_facebook_com_ctx); + sslCtxMgr.insertSSLCtxByDomainName( + "www.facebook.com", + strlen("www.facebook.com"), + www_facebook_com_ctx); + try { + sslCtxMgr.insertSSLCtxByDomainName( + "www.facebook.com", + strlen("www.facebook.com"), + std::make_shared()); + } catch (const std::exception& ex) { + } + sslCtxMgr.insertSSLCtxByDomainName( + "*.facebook.com", + strlen("*.facebook.com"), + start_facebook_com_ctx); + sslCtxMgr.insertSSLCtxByDomainName( + "*.abc.facebook.com", + strlen("*.abc.facebook.com"), + start_abc_facebook_com_ctx); + try { + sslCtxMgr.insertSSLCtxByDomainName( + "*.abc.facebook.com", + strlen("*.abc.facebook.com"), + std::make_shared()); + FAIL(); + } catch (const std::exception& ex) { + } + + shared_ptr retCtx; + retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com")); + EXPECT_EQ(retCtx, www_facebook_com_ctx); + retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com")); + EXPECT_EQ(retCtx, www_facebook_com_ctx); + EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com"))); + + retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com")); + EXPECT_EQ(retCtx, start_facebook_com_ctx); + retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com")); + EXPECT_EQ(retCtx, start_facebook_com_ctx); + + retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com")); + EXPECT_EQ(retCtx, start_abc_facebook_com_ctx); + + // ensure "facebook.com" does not match "*.facebook.com" + EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com"))); + // ensure "Xfacebook.com" does not match "*.facebook.com" + EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com"))); + // ensure wildcard name only matches one domain up + EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com"))); + + eventBase.loop(); // Clean up events before SSLContextManager is destructed +} + +}