/**
* Create a server/client AsyncSSLSocket
*/
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
- EventBase* evb, int fd, bool server,
- bool deferSecurityNegotiation) :
- AsyncSocket(evb, fd),
- server_(server),
- ctx_(ctx),
- handshakeTimeout_(this, evb),
- connectionTimeout_(this, evb) {
+AsyncSSLSocket::AsyncSSLSocket(
+ const shared_ptr<SSLContext>& ctx,
+ EventBase* evb,
+ int fd,
+ bool server,
+ bool deferSecurityNegotiation)
+ : AsyncSocket(evb, fd),
+ server_(server),
+ ctx_(ctx),
+ handshakeTimeout_(this, evb),
+ connectionTimeout_(this, evb) {
+ noTransparentTls_ = true;
+ init();
+ if (server) {
+ SSL_CTX_set_info_callback(
+ ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
+ }
+ if (deferSecurityNegotiation) {
+ sslState_ = STATE_UNENCRYPTED;
+ }
+}
+
+AsyncSSLSocket::AsyncSSLSocket(
+ const shared_ptr<SSLContext>& ctx,
+ AsyncSocket::UniquePtr oldAsyncSocket,
+ bool server,
+ bool deferSecurityNegotiation)
+ : AsyncSocket(std::move(oldAsyncSocket)),
+ server_(server),
+ ctx_(ctx),
+ handshakeTimeout_(this, oldAsyncSocket->getEventBase()),
+ connectionTimeout_(this, oldAsyncSocket->getEventBase()) {
noTransparentTls_ = true;
init();
if (server) {
* Create a client AsyncSSLSocket from an already connected fd
* and allow tlsext_hostname to be sent in Client Hello.
*/
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
- EventBase* evb, int fd,
- const std::string& serverName,
- bool deferSecurityNegotiation) :
- AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
+AsyncSSLSocket::AsyncSSLSocket(
+ const shared_ptr<SSLContext>& ctx,
+ EventBase* evb,
+ int fd,
+ const std::string& serverName,
+ bool deferSecurityNegotiation)
+ : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
tlsextHostname_ = serverName;
}
#endif // FOLLY_OPENSSL_HAS_SNI
/* register for a read operation (waiting for CLIENT HELLO) */
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
- if (preReceivedData_) {
- handleRead();
- }
+ checkForImmediateRead();
}
#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
// the socket to become readable again.
if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
AsyncSocket::handleRead();
+ } else {
+ AsyncSocket::checkForImmediateRead();
}
}
preverifyOk;
}
-void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
- CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
- CHECK(!preReceivedData_);
- preReceivedData_ = std::move(data);
-}
-
void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true;
clientHelloInfo_.reset(new ssl::ClientHelloInfo());
* @param deferSecurityNegotiation
* unencrypted data can be sent before sslConn/Accept
*/
- AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb, int fd,
- bool server = true, bool deferSecurityNegotiation = false);
+ AsyncSSLSocket(
+ const std::shared_ptr<folly::SSLContext>& ctx,
+ EventBase* evb,
+ int fd,
+ bool server = true,
+ bool deferSecurityNegotiation = false);
+ /**
+ * Create a server/client AsyncSSLSocket from an already connected
+ * AsyncSocket.
+ */
+ AsyncSSLSocket(
+ const std::shared_ptr<folly::SSLContext>& ctx,
+ AsyncSocket::UniquePtr oldAsyncSocket,
+ bool server = true,
+ bool deferSecurityNegotiation = false);
/**
* Helper function to create a server/client shared_ptr<AsyncSSLSocket>.
* @param fd File descriptor to take over (should be a connected socket).
* @param serverName tlsext_hostname that will be sent in ClientHello.
*/
- AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb,
- int fd,
- const std::string& serverName,
- bool deferSecurityNegotiation = false);
+ AsyncSSLSocket(
+ const std::shared_ptr<folly::SSLContext>& ctx,
+ EventBase* evb,
+ int fd,
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false);
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
virtual size_t getRawBytesReceived() const override;
void enableClientHelloParsing();
- void setPreReceivedData(std::unique_ptr<IOBuf> data);
-
/**
* Accept an SSL connection on the socket.
*
bool sessionResumptionAttempted_{false};
std::chrono::milliseconds totalConnectTimeout_{0};
- std::unique_ptr<IOBuf> preReceivedData_;
std::string sslVerificationAlert_;
};
#include <folly/io/async/AsyncSocket.h>
#include <folly/ExceptionWrapper.h>
+#include <folly/Portability.h>
#include <folly/SocketAddress.h>
+#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
-#include <folly/Portability.h>
+#include <folly/io/IOBufQueue.h>
#include <folly/portability/Fcntl.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/SysUio.h>
state_ = StateEnum::ESTABLISHED;
}
+AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
+ : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
+ preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
+}
+
// init() method, since constructor forwarding isn't supported in most
// compilers yet.
void AsyncSocket::init() {
VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
<< ", buflen=" << *buflen;
- int recvFlags = 0;
- if (peek_) {
- recvFlags |= MSG_PEEK;
+ if (preReceivedData_ && !preReceivedData_->empty()) {
+ VLOG(5) << "AsyncSocket::performRead() this=" << this
+ << ", reading pre-received data";
+
+ io::Cursor cursor(preReceivedData_.get());
+ auto len = cursor.pullAtMost(*buf, *buflen);
+
+ IOBufQueue queue;
+ queue.append(std::move(preReceivedData_));
+ queue.trimStart(len);
+ preReceivedData_ = queue.move();
+
+ appBytesReceived_ += len;
+ return ReadResult(len);
}
- ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
+ ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT);
if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now.
// be a pessimism. In most cases it probably wouldn't be readable, and we
// would just waste an extra system call. Even if it is readable, waiting to
// find out from libevent on the next event loop doesn't seem that bad.
+ //
+ // The exception to this is if we have pre-received data. In that case there
+ // is definitely data available immediately.
+ if (preReceivedData_ && !preReceivedData_->empty()) {
+ handleRead();
+ }
}
void AsyncSocket::handleInitialReadWrite() noexcept {
*/
AsyncSocket(EventBase* evb, int fd);
+ /**
+ * Create an AsyncSocket from a different, already connected AsyncSocket.
+ *
+ * Similar to AsyncSocket(evb, fd) when fd was previously owned by an
+ * AsyncSocket.
+ */
+ explicit AsyncSocket(AsyncSocket::UniquePtr);
+
/**
* Helper function to create a shared_ptr<AsyncSocket>.
*
* error. The AsyncSocket may no longer be used after the file descriptor
* has been extracted.
*
+ * This method should be used with care as the resulting fd is not guaranteed
+ * to perfectly reflect the state of the AsyncSocket (security state,
+ * pre-received data, etc.).
+ *
* Returns the file descriptor. The caller assumes ownership of the
* descriptor, and it will not be closed when the AsyncSocket is destroyed.
*/
return setsockopt(fd_, level, optname, optval, sizeof(T));
}
- virtual void setPeek(bool peek) {
- peek_ = peek;
+ /**
+ * Set pre-received data, to be returned to read callback before any data
+ * from the socket.
+ */
+ virtual void setPreReceivedData(std::unique_ptr<IOBuf> data) {
+ if (preReceivedData_) {
+ preReceivedData_->prependChain(std::move(data));
+ } else {
+ preReceivedData_ = std::move(data);
+ }
}
/**
size_t appBytesWritten_; ///< Num of bytes written to socket
bool isBufferMovable_{false};
- bool peek_{false}; // Peek bytes.
+ // Pre-received data, to be returned to read callback before any data from the
+ // socket.
+ std::unique_ptr<IOBuf> preReceivedData_;
int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
ASSERT_TRUE(errMsgCB.gotTimestamp_);
}
#endif // MSG_ERRQUEUE
+
+TEST(AsyncSocket, PreReceivedData) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+
+ ReadCallback peekCallback(2);
+ ReadCallback readCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("he", 2);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
+ acceptedSocket->setReadCB(nullptr);
+ acceptedSocket->setReadCB(&readCallback);
+ };
+ readCallback.dataAvailableCallback = [&]() {
+ if (readCallback.dataRead() == 5) {
+ readCallback.verifyData("hello", 5);
+ acceptedSocket->setReadCB(nullptr);
+ }
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataOnly) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+
+ ReadCallback peekCallback;
+ ReadCallback readCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("hello", 5);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+ acceptedSocket->setReadCB(&readCallback);
+ };
+ readCallback.dataAvailableCallback = [&]() {
+ readCallback.verifyData("hello", 5);
+ acceptedSocket->setReadCB(nullptr);
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataPartial) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket = server.acceptAsync(&evb);
+
+ ReadCallback peekCallback;
+ ReadCallback smallReadCallback(3);
+ ReadCallback normalReadCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("hello", 5);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+ acceptedSocket->setReadCB(&smallReadCallback);
+ };
+ smallReadCallback.dataAvailableCallback = [&]() {
+ smallReadCallback.verifyData("hel", 3);
+ acceptedSocket->setReadCB(&normalReadCallback);
+ };
+ normalReadCallback.dataAvailableCallback = [&]() {
+ normalReadCallback.verifyData("lo", 2);
+ acceptedSocket->setReadCB(nullptr);
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataTakeover) {
+ TestServer server;
+
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->connect(nullptr, server.getAddress(), 30);
+ evb.loop();
+
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+ auto acceptedSocket =
+ AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
+ AsyncSocket::UniquePtr takeoverSocket;
+
+ ReadCallback peekCallback(3);
+ ReadCallback readCallback;
+ peekCallback.dataAvailableCallback = [&]() {
+ peekCallback.verifyData("hel", 3);
+ acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+ acceptedSocket->setReadCB(nullptr);
+ takeoverSocket =
+ AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
+ takeoverSocket->setReadCB(&readCallback);
+ };
+ readCallback.dataAvailableCallback = [&]() {
+ readCallback.verifyData("hello", 5);
+ takeoverSocket->setReadCB(nullptr);
+ };
+
+ acceptedSocket->setReadCB(&peekCallback);
+
+ evb.loop();
+}
bool(const unsigned char**,
unsigned*,
SSLContext::NextProtocolType*));
- MOCK_METHOD1(setPeek, void(bool));
MOCK_METHOD1(setReadCB, void(ReadCallback*));
void sslConn(
MOCK_CONST_METHOD0(good, bool());
MOCK_CONST_METHOD0(readable, bool());
MOCK_CONST_METHOD0(hangup, bool());
- MOCK_METHOD1(setPeek, void(bool));
MOCK_METHOD1(setReadCB, void(ReadCallback*));
+ MOCK_METHOD1(_setPreReceivedData, void(std::unique_ptr<IOBuf>&));
+ void setPreReceivedData(std::unique_ptr<IOBuf> data) override {
+ return _setPreReceivedData(data);
+ }
};
}}