From 539129455f6c8300c5bd5b72379b29823266c7d6 Mon Sep 17 00:00:00 2001 From: Subodh Iyengar Date: Mon, 30 May 2016 21:08:33 -0700 Subject: [PATCH] Add support for TFO connections Summary: This adds support to establish connections over TFO. The API introduced here retains the same connect() + write() api that clients currently use. If enableTFO() is called then the connect will be deferred to the first write. This only works with request response protocols since a write must trigger the connect. There is a tradeoff here for the simpler API, and we can address this with other signals such as a short timeout in the future. Even though the client might enable TFO, the program might run on machines without TFO support. There were 2 choices for supporting machines where TFO might not be enabled: 1. Fallback to normal connect if tfo sendmsg fails 2. Check tfo supported on the machine before using it Both these have their tradeoffs, however option 1 does not require us to read from procfs in the common code path. Reviewed By: Orvid Differential Revision: D3327480 fbshipit-source-id: 9ac3a0c7ad2d206b158fdc305641fedbd93aa44d --- folly/detail/SocketFastOpen.cpp | 37 +- folly/detail/SocketFastOpen.h | 16 +- folly/io/async/AsyncSocket.cpp | 206 +++++++--- folly/io/async/AsyncSocket.h | 46 ++- folly/io/async/test/AsyncSocketTest.h | 9 +- folly/io/async/test/AsyncSocketTest2.cpp | 488 ++++++++++++++++++++++- folly/io/async/test/BlockingSocket.h | 10 +- folly/io/async/test/SocketClient.cpp | 70 ++++ 8 files changed, 799 insertions(+), 83 deletions(-) create mode 100644 folly/io/async/test/SocketClient.cpp diff --git a/folly/detail/SocketFastOpen.cpp b/folly/detail/SocketFastOpen.cpp index 90938a9a..4455377e 100644 --- a/folly/detail/SocketFastOpen.cpp +++ b/folly/detail/SocketFastOpen.cpp @@ -20,21 +20,42 @@ namespace folly { namespace detail { #if FOLLY_ALLOW_TFO -ssize_t tfo_sendto( - int sockfd, - const void* buf, - size_t len, - int flags, - const struct sockaddr* dest_addr, - socklen_t addrlen) { + +#include +#include + +// Sometimes these flags are not present in the headers, +// so define them if not present. +#if !defined(MSG_FASTOPEN) +#define MSG_FASTOPEN 0x20000000 +#endif + +#if !defined(TCP_FASTOPEN) +#define TCP_FASTOPEN 23 +#endif + +ssize_t tfo_sendmsg(int sockfd, const struct msghdr* msg, int flags) { flags |= MSG_FASTOPEN; - return sendto(sockfd, buf, len, flags, dest_addr, addrlen); + return sendmsg(sockfd, msg, flags); } int tfo_enable(int sockfd, size_t max_queue_size) { return setsockopt( sockfd, SOL_TCP, TCP_FASTOPEN, &max_queue_size, sizeof(max_queue_size)); } + +#else + +ssize_t tfo_sendmsg(int sockfd, const struct msghdr* msg, int flags) { + errno = EOPNOTSUPP; + return -1; +} + +int tfo_enable(int sockfd, size_t max_queue_size) { + errno = ENOPROTOOPT; + return -1; +} + #endif } } diff --git a/folly/detail/SocketFastOpen.h b/folly/detail/SocketFastOpen.h index 1295a1e6..d01e89b9 100644 --- a/folly/detail/SocketFastOpen.h +++ b/folly/detail/SocketFastOpen.h @@ -19,31 +19,23 @@ #include #include -#if !defined(FOLLY_ALLOW_TFO) && defined(TCP_FASTOPEN) && defined(MSG_FASTOPEN) +#if !defined(FOLLY_ALLOW_TFO) && defined(__linux__) && !defined(__ANDROID__) +// only allow for linux right now #define FOLLY_ALLOW_TFO 1 #endif namespace folly { namespace detail { -#if FOLLY_ALLOW_TFO - /** - * tfo_sendto has the same semantics as sendto, but is used to + * tfo_sendto has the same semantics as sendmsg, but is used to * send with TFO data. */ -ssize_t tfo_sendto( - int sockfd, - const void* buf, - size_t len, - int flags, - const struct sockaddr* dest_addr, - socklen_t addlen); +ssize_t tfo_sendmsg(int sockfd, const struct msghdr* msg, int flags); /** * Enable TFO on a listening socket. */ int tfo_enable(int sockfd, size_t max_queue_size); -#endif } } diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 0fe8e886..cffe1f82 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -16,8 +16,7 @@ #include -#include -#include +#include #include #include #include @@ -429,34 +428,12 @@ void AsyncSocket::connect(ConnectCallback* callback, // Perform the connect() address.getAddress(&addrStorage); - rv = ::connect(fd_, saddr, address.getActualSize()); - if (rv < 0) { - auto errnoCopy = errno; - if (errnoCopy == EINPROGRESS) { - // Connection in progress. - if (timeout > 0) { - // Start a timer in case the connection takes too long. - if (!writeTimeout_.scheduleTimeout(timeout)) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - withAddr("failed to schedule AsyncSocket connect timeout")); - } - } - - // Register for write events, so we'll - // be notified when the connection finishes/fails. - // Note that we don't register for a persistent event here. - assert(eventFlags_ == EventHandler::NONE); - eventFlags_ = EventHandler::WRITE; - if (!ioHandler_.registerHandler(eventFlags_)) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - withAddr("failed to register AsyncSocket connect handler")); - } + if (tfoEnabled_) { + state_ = StateEnum::FAST_OPEN; + tfoAttempted_ = true; + } else { + if (socketConnect(saddr, addr_.getActualSize()) < 0) { return; - } else { - throw AsyncSocketException( - AsyncSocketException::NOT_OPEN, - "connect failed (immediately)", - errnoCopy); } } @@ -481,10 +458,52 @@ void AsyncSocket::connect(ConnectCallback* callback, VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this; assert(readCallback_ == nullptr); assert(writeReqHead_ == nullptr); - state_ = StateEnum::ESTABLISHED; + if (state_ != StateEnum::FAST_OPEN) { + state_ = StateEnum::ESTABLISHED; + } invokeConnectSuccess(); } +int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { + int rv = ::connect(fd_, saddr, len); + if (rv < 0) { + auto errnoCopy = errno; + if (errnoCopy == EINPROGRESS) { + scheduleConnectTimeoutAndRegisterForEvents(); + } else { + throw AsyncSocketException( + AsyncSocketException::NOT_OPEN, + "connect failed (immediately)", + errnoCopy); + } + } + return rv; +} + +void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() { + // Connection in progress. + int timeout = connectTimeout_.count(); + if (timeout > 0) { + // Start a timer in case the connection takes too long. + if (!writeTimeout_.scheduleTimeout(timeout)) { + throw AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + withAddr("failed to schedule AsyncSocket connect timeout")); + } + } + + // Register for write events, so we'll + // be notified when the connection finishes/fails. + // Note that we don't register for a persistent event here. + assert(eventFlags_ == EventHandler::NONE); + eventFlags_ = EventHandler::WRITE; + if (!ioHandler_.registerHandler(eventFlags_)) { + throw AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + withAddr("failed to register AsyncSocket connect handler")); + } +} + void AsyncSocket::connect(ConnectCallback* callback, const string& ip, uint16_t port, int timeout, @@ -502,7 +521,7 @@ void AsyncSocket::connect(ConnectCallback* callback, void AsyncSocket::cancelConnect() { connectCallback_ = nullptr; - if (state_ == StateEnum::CONNECTING) { + if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) { closeNow(); } } @@ -514,7 +533,7 @@ void AsyncSocket::setSendTimeout(uint32_t milliseconds) { // If we are currently pending on write requests, immediately update // writeTimeout_ with the new value. if ((eventFlags_ & EventHandler::WRITE) && - (state_ != StateEnum::CONNECTING)) { + (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) { assert(state_ == StateEnum::ESTABLISHED); assert((shutdownFlags_ & SHUT_WRITE) == 0); if (sendTimeout_ > 0) { @@ -573,6 +592,7 @@ void AsyncSocket::setReadCB(ReadCallback *callback) { switch ((StateEnum)state_) { case StateEnum::CONNECTING: + case StateEnum::FAST_OPEN: // For convenience, we allow the read callback to be set while we are // still connecting. We just store the callback for now. Once the // connection completes we'll register for read events. @@ -683,7 +703,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, uint32_t partialWritten = 0; int bytesWritten = 0; bool mustRegister = false; - if (state_ == StateEnum::ESTABLISHED && !connecting()) { + if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) && + !connecting()) { if (writeReqHead_ == nullptr) { // If we are established and there are no other writes pending, // we can attempt to perform the write immediately. @@ -715,7 +736,13 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, bufferCallback_->onEgressBuffered(); } } - mustRegister = true; + if (!connecting()) { + // Writes might put the socket back into connecting state + // if TFO is enabled, and using TFO fails. + // This means that write timeouts would not be active, however + // connect timeouts would affect this stage. + mustRegister = true; + } } } else if (!connecting()) { // Invalid state for writing @@ -839,7 +866,7 @@ void AsyncSocket::closeNow() { switch (state_) { case StateEnum::ESTABLISHED: case StateEnum::CONNECTING: - { + case StateEnum::FAST_OPEN: { shutdownFlags_ |= (SHUT_READ | SHUT_WRITE); state_ = StateEnum::CLOSED; @@ -995,6 +1022,13 @@ void AsyncSocket::shutdownWriteNow() { // immediately shut down the write side of the socket. shutdownFlags_ |= SHUT_WRITE_PENDING; return; + case StateEnum::FAST_OPEN: + // In fast open state we haven't call connected yet, and if we shutdown + // the writes, we will never try to call connect, so shut everything down + shutdownFlags_ |= SHUT_WRITE; + // Immediately fail all write requests + failAllWrites(socketShutdownForWritesEx); + return; case StateEnum::CLOSED: case StateEnum::ERROR: // We should never get here. SHUT_WRITE should always be set @@ -1046,9 +1080,10 @@ bool AsyncSocket::hangup() const { } bool AsyncSocket::good() const { - return ((state_ == StateEnum::CONNECTING || - state_ == StateEnum::ESTABLISHED) && - (shutdownFlags_ == 0) && (eventBase_ != nullptr)); + return ( + (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN || + state_ == StateEnum::ESTABLISHED) && + (shutdownFlags_ == 0) && (eventBase_ != nullptr)); } bool AsyncSocket::error() const { @@ -1695,17 +1730,97 @@ void AsyncSocket::timeoutExpired() noexcept { if (state_ == StateEnum::CONNECTING) { // connect() timed out // Unregister for I/O events. - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "connect timed out"); - failConnect(__func__, ex); + if (connectCallback_) { + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, "connect timed out"); + failConnect(__func__, ex); + } else { + // we faced a connect error without a connect callback, which could + // happen due to TFO. + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, "write timed out during connection"); + failWrite(__func__, ex); + } } else { // a normal write operation timed out - assert(state_ == StateEnum::ESTABLISHED); AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out"); failWrite(__func__, ex); } } +ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) { + return detail::tfo_sendmsg(fd, msg, msg_flags); +} + +AsyncSocket::WriteResult AsyncSocket::sendSocketMessage( + struct msghdr* msg, + int msg_flags) { + ssize_t totalWritten = 0; + if (state_ == StateEnum::FAST_OPEN) { + sockaddr_storage addr; + auto len = addr_.getAddress(&addr); + msg->msg_name = &addr; + msg->msg_namelen = len; + totalWritten = tfoSendMsg(fd_, msg, msg_flags); + if (totalWritten >= 0) { + tfoFinished_ = true; + state_ = StateEnum::ESTABLISHED; + handleInitialReadWrite(); + } else if (errno == EINPROGRESS) { + VLOG(4) << "TFO falling back to connecting"; + // A normal sendmsg doesn't return EINPROGRESS, however + // TFO might fallback to connecting if there is no + // cookie. + state_ = StateEnum::CONNECTING; + try { + scheduleConnectTimeoutAndRegisterForEvents(); + } catch (const AsyncSocketException& ex) { + return WriteResult( + WRITE_ERROR, folly::make_unique(ex)); + } + // Let's fake it that no bytes were written. + // Some clients check errno even if return code is 0, so we + // set it just in case. + errno = EAGAIN; + totalWritten = 0; + } else if (errno == EOPNOTSUPP) { + VLOG(4) << "TFO not supported"; + // Try falling back to connecting. + state_ = StateEnum::CONNECTING; + try { + int ret = socketConnect((const sockaddr*)&addr, len); + if (ret == 0) { + // connect succeeded immediately + // Treat this like no data was written. + state_ = StateEnum::ESTABLISHED; + handleInitialReadWrite(); + } + // If there was no exception during connections, + // we would return that no bytes were written. + // Some clients check errno even if return code is 0, so we + // set it just in case. + errno = EAGAIN; + totalWritten = 0; + } catch (const AsyncSocketException& ex) { + return WriteResult( + WRITE_ERROR, folly::make_unique(ex)); + } + } else if (errno == EAGAIN) { + // Normally sendmsg would indicate that the write would block. + // However in the fast open case, it would indicate that sendmsg + // fell back to a connect. This is a return code from connect() + // instead, and is an error condition indicating no fds available. + return WriteResult( + WRITE_ERROR, + folly::make_unique( + AsyncSocketException::UNKNOWN, "No more free local ports")); + } + } else { + totalWritten = ::sendmsg(fd_, msg, msg_flags); + } + return WriteResult(totalWritten); +} + AsyncSocket::WriteResult AsyncSocket::performWrite( const iovec* vec, uint32_t count, @@ -1740,9 +1855,10 @@ AsyncSocket::WriteResult AsyncSocket::performWrite( // marks that this is the last byte of a record (response) msg_flags |= MSG_EOR; } - ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags); + auto writeResult = sendSocketMessage(&msg, msg_flags); + auto totalWritten = writeResult.writeReturn; if (totalWritten < 0) { - if (errno == EAGAIN) { + if (!writeResult.exception && errno == EAGAIN) { // TCP buffer is full; we can't write any more data right now. *countWritten = 0; *partialWritten = 0; @@ -1751,7 +1867,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite( // error *countWritten = 0; *partialWritten = 0; - return WriteResult(WRITE_ERROR); + return writeResult; } appBytesWritten_ += totalWritten; diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 2a77ac3b..f3e605d2 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -416,6 +417,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper { return connectTimeout_; } + bool getTFOAttempted() const { + return tfoAttempted_; + } + + /** + * Returns whether or not the attempt to use TFO + * finished successfully. This does not necessarily + * mean TFO worked, just that trying to use TFO + * succeeded. + */ + bool getTFOFinished() const { + return tfoFinished_; + } + // Methods controlling socket options /** @@ -509,12 +524,24 @@ class AsyncSocket : virtual public AsyncTransportWrapper { peek_ = peek; } + /** + * Enables TFO behavior on the AsyncSocket if FOLLY_ALLOW_TFO + * is set. + */ + void enableTFO() { + // No-op if folly does not allow tfo +#if FOLLY_ALLOW_TFO + tfoEnabled_ = true; +#endif + } + enum class StateEnum : uint8_t { UNINIT, CONNECTING, ESTABLISHED, CLOSED, - ERROR + ERROR, + FAST_OPEN, }; void setBufferCallback(BufferCallback* cb); @@ -784,6 +811,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper { uint32_t* countWritten, uint32_t* partialWritten); + /** + * Sends the message over the socket using sendmsg + * + * @param msg Message to send + * @param msg_flags Flags to pass to sendmsg + */ + AsyncSocket::WriteResult sendSocketMessage(struct msghdr* msg, int msg_flags); + + virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags); + + int socketConnect(const struct sockaddr* addr, socklen_t len); + + void scheduleConnectTimeoutAndRegisterForEvents(); + bool updateEventRegistration(); /** @@ -854,6 +895,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { std::chrono::milliseconds connectTimeout_{0}; BufferCallback* bufferCallback_{nullptr}; + bool tfoEnabled_{false}; + bool tfoAttempted_{false}; + bool tfoFinished_{false}; }; #ifdef _MSC_VER #pragma vtordisp(pop) diff --git a/folly/io/async/test/AsyncSocketTest.h b/folly/io/async/test/AsyncSocketTest.h index a0087c13..549d9c6b 100644 --- a/folly/io/async/test/AsyncSocketTest.h +++ b/folly/io/async/test/AsyncSocketTest.h @@ -72,6 +72,7 @@ class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback { void writeErr(size_t bytesWritten, const folly::AsyncSocketException& ex) noexcept override { + LOG(ERROR) << ex.what(); state = STATE_FAILED; this->bytesWritten = bytesWritten; exception = ex; @@ -205,8 +206,7 @@ class TestServer { public: // Create a TestServer. // This immediately starts listening on an ephemeral port. - TestServer() - : fd_(-1) { + explicit TestServer(bool enableTFO = false) : fd_(-1) { fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); if (fd_ < 0) { throw folly::AsyncSocketException( @@ -221,6 +221,11 @@ class TestServer { "non-blocking mode", errno); } + if (enableTFO) { +#if FOLLY_ALLOW_TFO + folly::detail::tfo_enable(fd_, 100); +#endif + } if (listen(fd_, 10) != 0) { throw folly::AsyncSocketException( folly::AsyncSocketException::INTERNAL_ERROR, diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 23b806ad..ca646031 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include +#include +#include #include #include #include #include -#include -#include -#include #include #include @@ -28,11 +29,12 @@ #include #include -#include #include -#include #include +#include +#include #include +#include #include using namespace boost; @@ -47,6 +49,7 @@ using std::chrono::milliseconds; using boost::scoped_array; using namespace folly; +using namespace testing; class DelayedWrite: public AsyncTimeout { public: @@ -100,6 +103,28 @@ TEST(AsyncSocketTest, Connect) { EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30)); } +enum class TFOState { + DISABLED, + ENABLED, +}; + +class AsyncSocketConnectTest : public ::testing::TestWithParam {}; + +std::vector getTestingValues() { + std::vector vals; + vals.emplace_back(TFOState::DISABLED); + +#if FOLLY_ALLOW_TFO + vals.emplace_back(TFOState::ENABLED); +#endif + return vals; +} + +INSTANTIATE_TEST_CASE_P( + ConnectTests, + AsyncSocketConnectTest, + ::testing::ValuesIn(getTestingValues())); + /** * Test connecting to a server that isn't listening */ @@ -115,10 +140,10 @@ TEST(AsyncSocketTest, ConnectRefused) { evb.loop(); - CHECK_EQ(cb.state, STATE_FAILED); - CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN); + EXPECT_EQ(STATE_FAILED, cb.state); + EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType()); EXPECT_LE(0, socket->getConnectTime().count()); - EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30)); + EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout()); } /** @@ -164,12 +189,17 @@ TEST(AsyncSocketTest, ConnectTimeout) { * Test writing immediately after connecting, without waiting for connect * to finish. */ -TEST(AsyncSocketTest, ConnectAndWrite) { +TEST_P(AsyncSocketConnectTest, ConnectAndWrite) { TestServer server; // connect() EventBase evb; std::shared_ptr socket = AsyncSocket::newSocket(&evb); + + if (GetParam() == TFOState::ENABLED) { + socket->enableTFO(); + } + ConnCallback ccb; socket->connect(&ccb, server.getAddress(), 30); @@ -198,12 +228,16 @@ TEST(AsyncSocketTest, ConnectAndWrite) { /** * Test connecting using a nullptr connect callback. */ -TEST(AsyncSocketTest, ConnectNullCallback) { +TEST_P(AsyncSocketConnectTest, ConnectNullCallback) { TestServer server; // connect() EventBase evb; std::shared_ptr socket = AsyncSocket::newSocket(&evb); + if (GetParam() == TFOState::ENABLED) { + socket->enableTFO(); + } + socket->connect(nullptr, server.getAddress(), 30); // write some data, just so we have some way of verifing @@ -231,12 +265,15 @@ TEST(AsyncSocketTest, ConnectNullCallback) { * * This exercises the STATE_CONNECTING_CLOSING code. */ -TEST(AsyncSocketTest, ConnectWriteAndClose) { +TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) { TestServer server; // connect() EventBase evb; std::shared_ptr socket = AsyncSocket::newSocket(&evb); + if (GetParam() == TFOState::ENABLED) { + socket->enableTFO(); + } ConnCallback ccb; socket->connect(&ccb, server.getAddress(), 30); @@ -374,18 +411,27 @@ TEST(AsyncSocketTest, ConnectWriteAndCloseNow) { /** * Test installing a read callback immediately, before connect() finishes. */ -TEST(AsyncSocketTest, ConnectAndRead) { +TEST_P(AsyncSocketConnectTest, ConnectAndRead) { TestServer server; // connect() EventBase evb; std::shared_ptr socket = AsyncSocket::newSocket(&evb); + if (GetParam() == TFOState::ENABLED) { + socket->enableTFO(); + } + ConnCallback ccb; socket->connect(&ccb, server.getAddress(), 30); ReadCallback rcb; socket->setReadCB(&rcb); + if (GetParam() == TFOState::ENABLED) { + // Trigger a connection + socket->writeChain(nullptr, IOBuf::copyBuffer("hey")); + } + // Even though we haven't looped yet, we should be able to accept // the connection and send data to it. std::shared_ptr acceptedSocket = server.accept(); @@ -399,7 +445,6 @@ TEST(AsyncSocketTest, ConnectAndRead) { evb.loop(); CHECK_EQ(ccb.state, STATE_SUCCEEDED); - CHECK_EQ(rcb.state, STATE_SUCCEEDED); CHECK_EQ(rcb.buffers.size(), 1); CHECK_EQ(rcb.buffers[0].length, sizeof(buf)); CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0); @@ -450,12 +495,15 @@ TEST(AsyncSocketTest, ConnectReadAndClose) { * Test both writing and installing a read callback immediately, * before connect() finishes. */ -TEST(AsyncSocketTest, ConnectWriteAndRead) { +TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) { TestServer server; // connect() EventBase evb; std::shared_ptr socket = AsyncSocket::newSocket(&evb); + if (GetParam() == TFOState::ENABLED) { + socket->enableTFO(); + } ConnCallback ccb; socket->connect(&ccb, server.getAddress(), 30); @@ -2309,3 +2357,415 @@ TEST(AsyncSocketTest, BufferCallbackKill) { evb.loop(); CHECK_EQ(ccb.state, STATE_SUCCEEDED); } + +#if FOLLY_ALLOW_TFO +TEST(AsyncSocketTest, ConnectTFO) { + // Start listening on a local port + TestServer server(true); + + // Connect using a AsyncSocket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->enableTFO(); + ConnCallback cb; + socket->connect(&cb, server.getAddress(), 30); + + std::array buf; + memset(buf.data(), 'a', buf.size()); + + std::array readBuf; + auto sendBuf = IOBuf::copyBuffer("hey"); + + std::thread t([&] { + auto acceptedSocket = server.accept(); + acceptedSocket->write(buf.data(), buf.size()); + acceptedSocket->flush(); + acceptedSocket->readAll(readBuf.data(), readBuf.size()); + acceptedSocket->close(); + }); + + evb.loop(); + + CHECK_EQ(cb.state, STATE_SUCCEEDED); + EXPECT_LE(0, socket->getConnectTime().count()); + EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30)); + EXPECT_TRUE(socket->getTFOAttempted()); + + // Should trigger the connect + WriteCallback write; + ReadCallback rcb; + socket->writeChain(&write, sendBuf->clone()); + socket->setReadCB(&rcb); + evb.loop(); + + t.join(); + + EXPECT_EQ(STATE_SUCCEEDED, write.state); + EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size())); + EXPECT_EQ(STATE_SUCCEEDED, rcb.state); + ASSERT_EQ(1, rcb.buffers.size()); + ASSERT_EQ(sizeof(buf), rcb.buffers[0].length); + EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size())); +} + +/** + * Test connecting to a server that isn't listening + */ +TEST(AsyncSocketTest, ConnectRefusedTFO) { + EventBase evb; + + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + + socket->enableTFO(); + + // Hopefully nothing is actually listening on this address + folly::SocketAddress addr("::1", 65535); + ConnCallback cb; + socket->connect(&cb, addr, 30); + + evb.loop(); + + WriteCallback write1; + // Trigger the connect if TFO attempt is supported. + socket->writeChain(&write1, IOBuf::copyBuffer("hey")); + evb.loop(); + WriteCallback write2; + socket->writeChain(&write2, IOBuf::copyBuffer("hey")); + evb.loop(); + + if (!socket->getTFOFinished()) { + EXPECT_EQ(STATE_FAILED, write1.state); + EXPECT_FALSE(socket->getTFOFinished()); + } else { + EXPECT_EQ(STATE_SUCCEEDED, write1.state); + EXPECT_TRUE(socket->getTFOFinished()); + } + + EXPECT_EQ(STATE_FAILED, write2.state); + + EXPECT_EQ(STATE_SUCCEEDED, cb.state); + EXPECT_LE(0, socket->getConnectTime().count()); + EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout()); + EXPECT_TRUE(socket->getTFOAttempted()); +} + +/** + * Test calling closeNow() immediately after connecting. + */ +TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) { + TestServer server(true); + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->enableTFO(); + + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + + // close() + socket->closeNow(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + + ASSERT_TRUE(socket->isClosedBySelf()); + ASSERT_FALSE(socket->isClosedByPeer()); +} + +/** + * Test calling close() immediately after connect() + */ +TEST(AsyncSocketTest, ConnectAndCloseTFO) { + TestServer server(true); + + // Connect using a AsyncSocket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->enableTFO(); + + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + socket->close(); + + // Loop, although there shouldn't be anything to do. + evb.loop(); + + // Make sure the connection was aborted + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + + ASSERT_TRUE(socket->isClosedBySelf()); + ASSERT_FALSE(socket->isClosedByPeer()); +} + +class MockAsyncTFOSocket : public AsyncSocket { + public: + using UniquePtr = std::unique_ptr; + + explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {} + + MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); +}; + +TEST(AsyncSocketTest, TestTFOUnsupported) { + TestServer server(true); + + // Connect using a AsyncSocket + EventBase evb; + auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb)); + socket->enableTFO(); + + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + + ReadCallback rcb; + socket->setReadCB(&rcb); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1)); + WriteCallback write; + auto sendBuf = IOBuf::copyBuffer("hey"); + socket->writeChain(&write, sendBuf->clone()); + EXPECT_EQ(STATE_WAITING, write.state); + + std::array buf; + memset(buf.data(), 'a', buf.size()); + + std::array readBuf; + + std::thread t([&] { + std::shared_ptr acceptedSocket = server.accept(); + acceptedSocket->write(buf.data(), buf.size()); + acceptedSocket->flush(); + acceptedSocket->readAll(readBuf.data(), readBuf.size()); + acceptedSocket->close(); + }); + + evb.loop(); + + t.join(); + EXPECT_EQ(STATE_SUCCEEDED, ccb.state); + EXPECT_EQ(STATE_SUCCEEDED, write.state); + + EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size())); + EXPECT_EQ(STATE_SUCCEEDED, rcb.state); + ASSERT_EQ(1, rcb.buffers.size()); + ASSERT_EQ(sizeof(buf), rcb.buffers[0].length); + EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size())); +} + +TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) { + // Try connecting to server that won't respond. + // + // This depends somewhat on the network where this test is run. + // Hopefully this IP will be routable but unresponsive. + // (Alternatively, we could try listening on a local raw socket, but that + // normally requires root privileges.) + auto host = SocketAddressTestHelper::isIPv6Enabled() + ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6 + : SocketAddressTestHelper::isIPv4Enabled() + ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4 + : nullptr; + SocketAddress addr(host, 65535); + + // Connect using a AsyncSocket + EventBase evb; + auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb)); + socket->enableTFO(); + + ConnCallback ccb; + // Set a very small timeout + socket->connect(&ccb, addr, 1); + EXPECT_EQ(STATE_SUCCEEDED, ccb.state); + + ReadCallback rcb; + socket->setReadCB(&rcb); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1)); + WriteCallback write; + socket->writeChain(&write, IOBuf::copyBuffer("hey")); + + evb.loop(); + + EXPECT_EQ(STATE_FAILED, write.state); +} + +TEST(AsyncSocketTest, TestTFOFallbackToConnect) { + TestServer server(true); + + // Connect using a AsyncSocket + EventBase evb; + auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb)); + socket->enableTFO(); + + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + + ReadCallback rcb; + socket->setReadCB(&rcb); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(Invoke([&](int fd, struct msghdr*, int) { + sockaddr_storage addr; + auto len = server.getAddress().getAddress(&addr); + return connect(fd, (const struct sockaddr*)&addr, len); + })); + WriteCallback write; + auto sendBuf = IOBuf::copyBuffer("hey"); + socket->writeChain(&write, sendBuf->clone()); + EXPECT_EQ(STATE_WAITING, write.state); + + std::array buf; + memset(buf.data(), 'a', buf.size()); + + std::array readBuf; + + std::thread t([&] { + std::shared_ptr acceptedSocket = server.accept(); + acceptedSocket->write(buf.data(), buf.size()); + acceptedSocket->flush(); + acceptedSocket->readAll(readBuf.data(), readBuf.size()); + acceptedSocket->close(); + }); + + evb.loop(); + + t.join(); + EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size())); + + EXPECT_EQ(STATE_SUCCEEDED, ccb.state); + EXPECT_EQ(STATE_SUCCEEDED, write.state); + + EXPECT_EQ(STATE_SUCCEEDED, rcb.state); + ASSERT_EQ(1, rcb.buffers.size()); + ASSERT_EQ(buf.size(), rcb.buffers[0].length); + EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size())); +} + +TEST(AsyncSocketTest, TestTFOFallbackTimeout) { + // Try connecting to server that won't respond. + // + // This depends somewhat on the network where this test is run. + // Hopefully this IP will be routable but unresponsive. + // (Alternatively, we could try listening on a local raw socket, but that + // normally requires root privileges.) + auto host = SocketAddressTestHelper::isIPv6Enabled() + ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6 + : SocketAddressTestHelper::isIPv4Enabled() + ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4 + : nullptr; + SocketAddress addr(host, 65535); + + // Connect using a AsyncSocket + EventBase evb; + auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb)); + socket->enableTFO(); + + ConnCallback ccb; + // Set a very small timeout + socket->connect(&ccb, addr, 1); + EXPECT_EQ(STATE_SUCCEEDED, ccb.state); + + ReadCallback rcb; + socket->setReadCB(&rcb); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(Invoke([&](int fd, struct msghdr*, int) { + sockaddr_storage addr2; + auto len = addr.getAddress(&addr2); + return connect(fd, (const struct sockaddr*)&addr2, len); + })); + WriteCallback write; + socket->writeChain(&write, IOBuf::copyBuffer("hey")); + + evb.loop(); + + EXPECT_EQ(STATE_FAILED, write.state); +} + +TEST(AsyncSocketTest, TestTFOEagain) { + TestServer server(true); + + // Connect using a AsyncSocket + EventBase evb; + auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb)); + socket->enableTFO(); + + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(SetErrnoAndReturn(EAGAIN, -1)); + WriteCallback write; + socket->writeChain(&write, IOBuf::copyBuffer("hey")); + + evb.loop(); + + EXPECT_EQ(STATE_SUCCEEDED, ccb.state); + EXPECT_EQ(STATE_FAILED, write.state); +} + +// Sending a large amount of data in the first write which will +// definitely not fit into MSS. +TEST(AsyncSocketTest, ConnectTFOWithBigData) { + // Start listening on a local port + TestServer server(true); + + // Connect using a AsyncSocket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + socket->enableTFO(); + ConnCallback cb; + socket->connect(&cb, server.getAddress(), 30); + + std::array buf; + memset(buf.data(), 'a', buf.size()); + + constexpr size_t len = 10 * 1024; + auto sendBuf = IOBuf::create(len); + sendBuf->append(len); + std::array readBuf; + + std::thread t([&] { + auto acceptedSocket = server.accept(); + acceptedSocket->write(buf.data(), buf.size()); + acceptedSocket->flush(); + acceptedSocket->readAll(readBuf.data(), readBuf.size()); + acceptedSocket->close(); + }); + + evb.loop(); + + CHECK_EQ(cb.state, STATE_SUCCEEDED); + EXPECT_LE(0, socket->getConnectTime().count()); + EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30)); + EXPECT_TRUE(socket->getTFOAttempted()); + + // Should trigger the connect + WriteCallback write; + ReadCallback rcb; + socket->writeChain(&write, sendBuf->clone()); + socket->setReadCB(&rcb); + evb.loop(); + + t.join(); + + EXPECT_EQ(STATE_SUCCEEDED, write.state); + EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size())); + EXPECT_EQ(STATE_SUCCEEDED, rcb.state); + ASSERT_EQ(1, rcb.buffers.size()); + ASSERT_EQ(sizeof(buf), rcb.buffers[0].length); + EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size())); +} + +#endif diff --git a/folly/io/async/test/BlockingSocket.h b/folly/io/async/test/BlockingSocket.h index 360cfcb1..3830648e 100644 --- a/folly/io/async/test/BlockingSocket.h +++ b/folly/io/async/test/BlockingSocket.h @@ -40,6 +40,10 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, sock_->attachEventBase(&eventBase_); } + void setAddress(folly::SocketAddress address) { + address_ = address; + } + void open() { sock_->connect(this, address_); eventBase_.loop(); @@ -110,11 +114,15 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, } int32_t readHelper(uint8_t *buf, size_t len, bool all) { + if (!sock_->good()) { + return 0; + } + readBuf_ = buf; readLen_ = len; sock_->setReadCB(this); while (!err_ && sock_->good() && readLen_ > 0) { - eventBase_.loop(); + eventBase_.loopOnce(); if (!all) { break; } diff --git a/folly/io/async/test/SocketClient.cpp b/folly/io/async/test/SocketClient.cpp new file mode 100644 index 00000000..7f20d480 --- /dev/null +++ b/folly/io/async/test/SocketClient.cpp @@ -0,0 +1,70 @@ +/* + * Copyright 2016 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +using namespace folly; + +DEFINE_string(host, "localhost", "Host"); +DEFINE_int32(port, 0, "port"); +DEFINE_bool(tfo, false, "enable tfo"); +DEFINE_string(msg, "", "Message to send"); + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_port == 0) { + LOG(ERROR) << "Must specify port"; + exit(EXIT_FAILURE); + } + + // Prep the socket + EventBase evb; + AsyncSocket::UniquePtr socket(new AsyncSocket(&evb)); + socket->detachEventBase(); + + if (FLAGS_tfo) { +#if FOLLY_ALLOW_TFO + socket->enableTFO(); +#endif + } + + // Keep this around + auto sockAddr = socket.get(); + + BlockingSocket sock(std::move(socket)); + SocketAddress addr; + addr.setFromHostPort(FLAGS_host, FLAGS_port); + sock.setAddress(addr); + sock.open(); + LOG(INFO) << "connected to " << addr.getAddressStr(); + + sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size()); + + LOG(ERROR) << "TFO attempted: " << sockAddr->getTFOAttempted(); + LOG(ERROR) << "TFO finished: " << sockAddr->getTFOFinished(); + + std::array buf; + int32_t bytesRead = 0; + while ((bytesRead = sock.read((uint8_t*)buf.data(), buf.size())) != 0) { + std::cout << std::string(buf.data(), bytesRead); + } + + sock.close(); + return 0; +} -- 2.34.1