From: Kyle Nekritz Date: Thu, 9 Mar 2017 16:25:27 +0000 (-0800) Subject: Replace MSG_PEEK with a pre-received data interface. X-Git-Tag: v2017.03.13.00~9 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=fb9776fcdb9676e1bfd06afecc8f8c9e74f28439;p=folly.git Replace MSG_PEEK with a pre-received data interface. Summary: MSG_PEEK was difficult if not impossible to use well since we do not provide a way wait for more data to arrive. If you are using setPeek on AsyncSocket, and you do not receive the amount of data you want, you must either abandon your peek attempt, or spin around the event base waiting for more data. This diff replaces the peek interface on AsyncSocket with a pre-received data interface, allowing users to insert data back onto the front of connections after reading some data in another layer. Reviewed By: djwatson Differential Revision: D4626315 fbshipit-source-id: c552e64f5b3ac9e40ea3358d65b4b9db848f5d74 --- diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index a25269d2..3b99a4c4 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -218,14 +218,38 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, /** * Create a server/client AsyncSSLSocket */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, bool server, - bool deferSecurityNegotiation) : - AsyncSocket(evb, fd), - server_(server), - ctx_(ctx), - handshakeTimeout_(this, evb), - connectionTimeout_(this, evb) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(evb, fd), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { + noTransparentTls_ = true; + init(); + if (server) { + SSL_CTX_set_info_callback( + ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback); + } + if (deferSecurityNegotiation) { + sslState_ = STATE_UNENCRYPTED; + } +} + +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(std::move(oldAsyncSocket)), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, oldAsyncSocket->getEventBase()), + connectionTimeout_(this, oldAsyncSocket->getEventBase()) { noTransparentTls_ = true; init(); if (server) { @@ -254,11 +278,13 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, * Create a client AsyncSSLSocket from an already connected fd * and allow tlsext_hostname to be sent in Client Hello. */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, - const std::string& serverName, - bool deferSecurityNegotiation) : - AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation) + : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { tlsextHostname_ = serverName; } #endif // FOLLY_OPENSSL_HAS_SNI @@ -451,9 +477,7 @@ void AsyncSSLSocket::sslAccept( /* register for a read operation (waiting for CLIENT HELLO) */ updateEventRegistration(EventHandler::READ, EventHandler::WRITE); - if (preReceivedData_) { - handleRead(); - } + checkForImmediateRead(); } #if OPENSSL_VERSION_NUMBER >= 0x009080bfL @@ -985,6 +1009,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept { // the socket to become readable again. if (ssl_ != nullptr && SSL_pending(ssl_) > 0) { AsyncSocket::handleRead(); + } else { + AsyncSocket::checkForImmediateRead(); } } @@ -1684,12 +1710,6 @@ int AsyncSSLSocket::sslVerifyCallback( preverifyOk; } -void AsyncSSLSocket::setPreReceivedData(std::unique_ptr data) { - CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED); - CHECK(!preReceivedData_); - preReceivedData_ = std::move(data); -} - void AsyncSSLSocket::enableClientHelloParsing() { parseClientHello_ = true; clientHelloInfo_.reset(new ssl::ClientHelloInfo()); diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index a8bb1e12..2121c2ff 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param deferSecurityNegotiation * unencrypted data can be sent before sslConn/Accept */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, int fd, - bool server = true, bool deferSecurityNegotiation = false); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + bool server = true, + bool deferSecurityNegotiation = false); + /** + * Create a server/client AsyncSSLSocket from an already connected + * AsyncSocket. + */ + AsyncSSLSocket( + const std::shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server = true, + bool deferSecurityNegotiation = false); /** * Helper function to create a server/client shared_ptr. @@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param fd File descriptor to take over (should be a connected socket). * @param serverName tlsext_hostname that will be sent in ClientHello. */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, - int fd, - const std::string& serverName, - bool deferSecurityNegotiation = false); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation = false); static std::shared_ptr newSocket( const std::shared_ptr& ctx, @@ -276,8 +289,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { virtual size_t getRawBytesReceived() const override; void enableClientHelloParsing(); - void setPreReceivedData(std::unique_ptr data); - /** * Accept an SSL connection on the socket. * @@ -864,7 +875,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { bool sessionResumptionAttempted_{false}; std::chrono::milliseconds totalConnectTimeout_{0}; - std::unique_ptr preReceivedData_; std::string sslVerificationAlert_; }; diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 721686c5..827ee8e9 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -17,9 +17,11 @@ #include #include +#include #include +#include #include -#include +#include #include #include #include @@ -229,6 +231,11 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd) state_ = StateEnum::ESTABLISHED; } +AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket) + : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) { + preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_); +} + // init() method, since constructor forwarding isn't supported in most // compilers yet. void AsyncSocket::init() { @@ -1406,12 +1413,23 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) { VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf << ", buflen=" << *buflen; - int recvFlags = 0; - if (peek_) { - recvFlags |= MSG_PEEK; + if (preReceivedData_ && !preReceivedData_->empty()) { + VLOG(5) << "AsyncSocket::performRead() this=" << this + << ", reading pre-received data"; + + io::Cursor cursor(preReceivedData_.get()); + auto len = cursor.pullAtMost(*buf, *buflen); + + IOBufQueue queue; + queue.append(std::move(preReceivedData_)); + queue.trimStart(len); + preReceivedData_ = queue.move(); + + appBytesReceived_ += len; + return ReadResult(len); } - ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags); + ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT); if (bytes < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // No more data to read right now. @@ -1762,6 +1780,12 @@ void AsyncSocket::checkForImmediateRead() noexcept { // be a pessimism. In most cases it probably wouldn't be readable, and we // would just waste an extra system call. Even if it is readable, waiting to // find out from libevent on the next event loop doesn't seem that bad. + // + // The exception to this is if we have pre-received data. In that case there + // is definitely data available immediately. + if (preReceivedData_ && !preReceivedData_->empty()) { + handleRead(); + } } void AsyncSocket::handleInitialReadWrite() noexcept { diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 54b37705..3e2adbce 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -189,6 +189,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { */ AsyncSocket(EventBase* evb, int fd); + /** + * Create an AsyncSocket from a different, already connected AsyncSocket. + * + * Similar to AsyncSocket(evb, fd) when fd was previously owned by an + * AsyncSocket. + */ + explicit AsyncSocket(AsyncSocket::UniquePtr); + /** * Helper function to create a shared_ptr. * @@ -264,6 +272,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * error. The AsyncSocket may no longer be used after the file descriptor * has been extracted. * + * This method should be used with care as the resulting fd is not guaranteed + * to perfectly reflect the state of the AsyncSocket (security state, + * pre-received data, etc.). + * * Returns the file descriptor. The caller assumes ownership of the * descriptor, and it will not be closed when the AsyncSocket is destroyed. */ @@ -601,8 +613,16 @@ class AsyncSocket : virtual public AsyncTransportWrapper { return setsockopt(fd_, level, optname, optval, sizeof(T)); } - virtual void setPeek(bool peek) { - peek_ = peek; + /** + * Set pre-received data, to be returned to read callback before any data + * from the socket. + */ + virtual void setPreReceivedData(std::unique_ptr data) { + if (preReceivedData_) { + preReceivedData_->prependChain(std::move(data)); + } else { + preReceivedData_ = std::move(data); + } } /** @@ -998,7 +1018,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { size_t appBytesWritten_; ///< Num of bytes written to socket bool isBufferMovable_{false}; - bool peek_{false}; // Peek bytes. + // Pre-received data, to be returned to read callback before any data from the + // socket. + std::unique_ptr preReceivedData_; int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any. diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index f5323151..864a4031 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -2909,3 +2909,133 @@ TEST(AsyncSocketTest, ErrMessageCallback) { ASSERT_TRUE(errMsgCB.gotTimestamp_); } #endif // MSG_ERRQUEUE + +TEST(AsyncSocket, PreReceivedData) { + TestServer server; + + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->connect(nullptr, server.getAddress(), 30); + evb.loop(); + + socket->writeChain(nullptr, IOBuf::copyBuffer("hello")); + + auto acceptedSocket = server.acceptAsync(&evb); + + ReadCallback peekCallback(2); + ReadCallback readCallback; + peekCallback.dataAvailableCallback = [&]() { + peekCallback.verifyData("he", 2); + acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h")); + acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e")); + acceptedSocket->setReadCB(nullptr); + acceptedSocket->setReadCB(&readCallback); + }; + readCallback.dataAvailableCallback = [&]() { + if (readCallback.dataRead() == 5) { + readCallback.verifyData("hello", 5); + acceptedSocket->setReadCB(nullptr); + } + }; + + acceptedSocket->setReadCB(&peekCallback); + + evb.loop(); +} + +TEST(AsyncSocket, PreReceivedDataOnly) { + TestServer server; + + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->connect(nullptr, server.getAddress(), 30); + evb.loop(); + + socket->writeChain(nullptr, IOBuf::copyBuffer("hello")); + + auto acceptedSocket = server.acceptAsync(&evb); + + ReadCallback peekCallback; + ReadCallback readCallback; + peekCallback.dataAvailableCallback = [&]() { + peekCallback.verifyData("hello", 5); + acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello")); + acceptedSocket->setReadCB(&readCallback); + }; + readCallback.dataAvailableCallback = [&]() { + readCallback.verifyData("hello", 5); + acceptedSocket->setReadCB(nullptr); + }; + + acceptedSocket->setReadCB(&peekCallback); + + evb.loop(); +} + +TEST(AsyncSocket, PreReceivedDataPartial) { + TestServer server; + + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->connect(nullptr, server.getAddress(), 30); + evb.loop(); + + socket->writeChain(nullptr, IOBuf::copyBuffer("hello")); + + auto acceptedSocket = server.acceptAsync(&evb); + + ReadCallback peekCallback; + ReadCallback smallReadCallback(3); + ReadCallback normalReadCallback; + peekCallback.dataAvailableCallback = [&]() { + peekCallback.verifyData("hello", 5); + acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello")); + acceptedSocket->setReadCB(&smallReadCallback); + }; + smallReadCallback.dataAvailableCallback = [&]() { + smallReadCallback.verifyData("hel", 3); + acceptedSocket->setReadCB(&normalReadCallback); + }; + normalReadCallback.dataAvailableCallback = [&]() { + normalReadCallback.verifyData("lo", 2); + acceptedSocket->setReadCB(nullptr); + }; + + acceptedSocket->setReadCB(&peekCallback); + + evb.loop(); +} + +TEST(AsyncSocket, PreReceivedDataTakeover) { + TestServer server; + + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->connect(nullptr, server.getAddress(), 30); + evb.loop(); + + socket->writeChain(nullptr, IOBuf::copyBuffer("hello")); + + auto acceptedSocket = + AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD())); + AsyncSocket::UniquePtr takeoverSocket; + + ReadCallback peekCallback(3); + ReadCallback readCallback; + peekCallback.dataAvailableCallback = [&]() { + peekCallback.verifyData("hel", 3); + acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello")); + acceptedSocket->setReadCB(nullptr); + takeoverSocket = + AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket))); + takeoverSocket->setReadCB(&readCallback); + }; + readCallback.dataAvailableCallback = [&]() { + readCallback.verifyData("hello", 5); + takeoverSocket->setReadCB(nullptr); + }; + + acceptedSocket->setReadCB(&peekCallback); + + evb.loop(); +} diff --git a/folly/io/async/test/MockAsyncSSLSocket.h b/folly/io/async/test/MockAsyncSSLSocket.h index 7ab4fda3..a627ba46 100644 --- a/folly/io/async/test/MockAsyncSSLSocket.h +++ b/folly/io/async/test/MockAsyncSSLSocket.h @@ -50,7 +50,6 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { bool(const unsigned char**, unsigned*, SSLContext::NextProtocolType*)); - MOCK_METHOD1(setPeek, void(bool)); MOCK_METHOD1(setReadCB, void(ReadCallback*)); void sslConn( diff --git a/folly/io/async/test/MockAsyncSocket.h b/folly/io/async/test/MockAsyncSocket.h index cf55874e..d6cb3da2 100644 --- a/folly/io/async/test/MockAsyncSocket.h +++ b/folly/io/async/test/MockAsyncSocket.h @@ -45,8 +45,11 @@ class MockAsyncSocket : public AsyncSocket { MOCK_CONST_METHOD0(good, bool()); MOCK_CONST_METHOD0(readable, bool()); MOCK_CONST_METHOD0(hangup, bool()); - MOCK_METHOD1(setPeek, void(bool)); MOCK_METHOD1(setReadCB, void(ReadCallback*)); + MOCK_METHOD1(_setPreReceivedData, void(std::unique_ptr&)); + void setPreReceivedData(std::unique_ptr data) override { + return _setPreReceivedData(data); + } }; }}