#include <folly/io/async/AsyncSocket.h>
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/EventHandler.h>
+#include <folly/ExceptionWrapper.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
#include <folly/portability/Fcntl.h>
// Perform the connect()
address.getAddress(&addrStorage);
- rv = ::connect(fd_, saddr, address.getActualSize());
- if (rv < 0) {
- auto errnoCopy = errno;
- if (errnoCopy == EINPROGRESS) {
- // Connection in progress.
- if (timeout > 0) {
- // Start a timer in case the connection takes too long.
- if (!writeTimeout_.scheduleTimeout(timeout)) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to schedule AsyncSocket connect timeout"));
- }
- }
-
- // Register for write events, so we'll
- // be notified when the connection finishes/fails.
- // Note that we don't register for a persistent event here.
- assert(eventFlags_ == EventHandler::NONE);
- eventFlags_ = EventHandler::WRITE;
- if (!ioHandler_.registerHandler(eventFlags_)) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to register AsyncSocket connect handler"));
- }
+ if (tfoEnabled_) {
+ state_ = StateEnum::FAST_OPEN;
+ tfoAttempted_ = true;
+ } else {
+ if (socketConnect(saddr, addr_.getActualSize()) < 0) {
return;
- } else {
- throw AsyncSocketException(
- AsyncSocketException::NOT_OPEN,
- "connect failed (immediately)",
- errnoCopy);
}
}
VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
assert(readCallback_ == nullptr);
assert(writeReqHead_ == nullptr);
- state_ = StateEnum::ESTABLISHED;
+ if (state_ != StateEnum::FAST_OPEN) {
+ state_ = StateEnum::ESTABLISHED;
+ }
invokeConnectSuccess();
}
+int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
+ int rv = ::connect(fd_, saddr, len);
+ if (rv < 0) {
+ auto errnoCopy = errno;
+ if (errnoCopy == EINPROGRESS) {
+ scheduleConnectTimeoutAndRegisterForEvents();
+ } else {
+ throw AsyncSocketException(
+ AsyncSocketException::NOT_OPEN,
+ "connect failed (immediately)",
+ errnoCopy);
+ }
+ }
+ return rv;
+}
+
+void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
+ // Connection in progress.
+ int timeout = connectTimeout_.count();
+ if (timeout > 0) {
+ // Start a timer in case the connection takes too long.
+ if (!writeTimeout_.scheduleTimeout(timeout)) {
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to schedule AsyncSocket connect timeout"));
+ }
+ }
+
+ // Register for write events, so we'll
+ // be notified when the connection finishes/fails.
+ // Note that we don't register for a persistent event here.
+ assert(eventFlags_ == EventHandler::NONE);
+ eventFlags_ = EventHandler::WRITE;
+ if (!ioHandler_.registerHandler(eventFlags_)) {
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to register AsyncSocket connect handler"));
+ }
+}
+
void AsyncSocket::connect(ConnectCallback* callback,
const string& ip, uint16_t port,
int timeout,
void AsyncSocket::cancelConnect() {
connectCallback_ = nullptr;
- if (state_ == StateEnum::CONNECTING) {
+ if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
closeNow();
}
}
// If we are currently pending on write requests, immediately update
// writeTimeout_ with the new value.
if ((eventFlags_ & EventHandler::WRITE) &&
- (state_ != StateEnum::CONNECTING)) {
+ (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
assert(state_ == StateEnum::ESTABLISHED);
assert((shutdownFlags_ & SHUT_WRITE) == 0);
if (sendTimeout_ > 0) {
switch ((StateEnum)state_) {
case StateEnum::CONNECTING:
+ case StateEnum::FAST_OPEN:
// For convenience, we allow the read callback to be set while we are
// still connecting. We just store the callback for now. Once the
// connection completes we'll register for read events.
uint32_t partialWritten = 0;
int bytesWritten = 0;
bool mustRegister = false;
- if (state_ == StateEnum::ESTABLISHED && !connecting()) {
+ if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
+ !connecting()) {
if (writeReqHead_ == nullptr) {
// If we are established and there are no other writes pending,
// we can attempt to perform the write immediately.
bufferCallback_->onEgressBuffered();
}
}
- mustRegister = true;
+ if (!connecting()) {
+ // Writes might put the socket back into connecting state
+ // if TFO is enabled, and using TFO fails.
+ // This means that write timeouts would not be active, however
+ // connect timeouts would affect this stage.
+ mustRegister = true;
+ }
}
} else if (!connecting()) {
// Invalid state for writing
switch (state_) {
case StateEnum::ESTABLISHED:
case StateEnum::CONNECTING:
- {
+ case StateEnum::FAST_OPEN: {
shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
state_ = StateEnum::CLOSED;
// immediately shut down the write side of the socket.
shutdownFlags_ |= SHUT_WRITE_PENDING;
return;
+ case StateEnum::FAST_OPEN:
+ // In fast open state we haven't call connected yet, and if we shutdown
+ // the writes, we will never try to call connect, so shut everything down
+ shutdownFlags_ |= SHUT_WRITE;
+ // Immediately fail all write requests
+ failAllWrites(socketShutdownForWritesEx);
+ return;
case StateEnum::CLOSED:
case StateEnum::ERROR:
// We should never get here. SHUT_WRITE should always be set
}
bool AsyncSocket::good() const {
- return ((state_ == StateEnum::CONNECTING ||
- state_ == StateEnum::ESTABLISHED) &&
- (shutdownFlags_ == 0) && (eventBase_ != nullptr));
+ return (
+ (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
+ state_ == StateEnum::ESTABLISHED) &&
+ (shutdownFlags_ == 0) && (eventBase_ != nullptr));
}
bool AsyncSocket::error() const {
if (state_ == StateEnum::CONNECTING) {
// connect() timed out
// Unregister for I/O events.
- AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
- "connect timed out");
- failConnect(__func__, ex);
+ if (connectCallback_) {
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT, "connect timed out");
+ failConnect(__func__, ex);
+ } else {
+ // we faced a connect error without a connect callback, which could
+ // happen due to TFO.
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT, "write timed out during connection");
+ failWrite(__func__, ex);
+ }
} else {
// a normal write operation timed out
- assert(state_ == StateEnum::ESTABLISHED);
AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
failWrite(__func__, ex);
}
}
+ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
+ return detail::tfo_sendmsg(fd, msg, msg_flags);
+}
+
+AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
+ struct msghdr* msg,
+ int msg_flags) {
+ ssize_t totalWritten = 0;
+ if (state_ == StateEnum::FAST_OPEN) {
+ sockaddr_storage addr;
+ auto len = addr_.getAddress(&addr);
+ msg->msg_name = &addr;
+ msg->msg_namelen = len;
+ totalWritten = tfoSendMsg(fd_, msg, msg_flags);
+ if (totalWritten >= 0) {
+ tfoFinished_ = true;
+ state_ = StateEnum::ESTABLISHED;
+ handleInitialReadWrite();
+ } else if (errno == EINPROGRESS) {
+ VLOG(4) << "TFO falling back to connecting";
+ // A normal sendmsg doesn't return EINPROGRESS, however
+ // TFO might fallback to connecting if there is no
+ // cookie.
+ state_ = StateEnum::CONNECTING;
+ try {
+ scheduleConnectTimeoutAndRegisterForEvents();
+ } catch (const AsyncSocketException& ex) {
+ return WriteResult(
+ WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+ }
+ // Let's fake it that no bytes were written.
+ // Some clients check errno even if return code is 0, so we
+ // set it just in case.
+ errno = EAGAIN;
+ totalWritten = 0;
+ } else if (errno == EOPNOTSUPP) {
+ VLOG(4) << "TFO not supported";
+ // Try falling back to connecting.
+ state_ = StateEnum::CONNECTING;
+ try {
+ int ret = socketConnect((const sockaddr*)&addr, len);
+ if (ret == 0) {
+ // connect succeeded immediately
+ // Treat this like no data was written.
+ state_ = StateEnum::ESTABLISHED;
+ handleInitialReadWrite();
+ }
+ // If there was no exception during connections,
+ // we would return that no bytes were written.
+ // Some clients check errno even if return code is 0, so we
+ // set it just in case.
+ errno = EAGAIN;
+ totalWritten = 0;
+ } catch (const AsyncSocketException& ex) {
+ return WriteResult(
+ WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+ }
+ } else if (errno == EAGAIN) {
+ // Normally sendmsg would indicate that the write would block.
+ // However in the fast open case, it would indicate that sendmsg
+ // fell back to a connect. This is a return code from connect()
+ // instead, and is an error condition indicating no fds available.
+ return WriteResult(
+ WRITE_ERROR,
+ folly::make_unique<AsyncSocketException>(
+ AsyncSocketException::UNKNOWN, "No more free local ports"));
+ }
+ } else {
+ totalWritten = ::sendmsg(fd_, msg, msg_flags);
+ }
+ return WriteResult(totalWritten);
+}
+
AsyncSocket::WriteResult AsyncSocket::performWrite(
const iovec* vec,
uint32_t count,
// marks that this is the last byte of a record (response)
msg_flags |= MSG_EOR;
}
- ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags);
+ auto writeResult = sendSocketMessage(&msg, msg_flags);
+ auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0) {
- if (errno == EAGAIN) {
+ if (!writeResult.exception && errno == EAGAIN) {
// TCP buffer is full; we can't write any more data right now.
*countWritten = 0;
*partialWritten = 0;
// error
*countWritten = 0;
*partialWritten = 0;
- return WriteResult(WRITE_ERROR);
+ return writeResult;
}
appBytesWritten_ += totalWritten;
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#include <folly/ExceptionWrapper.h>
+#include <folly/RWSpinLock.h>
+#include <folly/Random.h>
+#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h>
-#include <folly/RWSpinLock.h>
-#include <folly/SocketAddress.h>
-#include <folly/Random.h>
#include <folly/io/IOBuf.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/portability/Unistd.h>
#include <folly/test/SocketAddressTestHelper.h>
-#include <gtest/gtest.h>
#include <boost/scoped_array.hpp>
-#include <iostream>
#include <fcntl.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
#include <sys/types.h>
+#include <iostream>
#include <thread>
using namespace boost;
using boost::scoped_array;
using namespace folly;
+using namespace testing;
class DelayedWrite: public AsyncTimeout {
public:
EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
}
+enum class TFOState {
+ DISABLED,
+ ENABLED,
+};
+
+class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
+
+std::vector<TFOState> getTestingValues() {
+ std::vector<TFOState> vals;
+ vals.emplace_back(TFOState::DISABLED);
+
+#if FOLLY_ALLOW_TFO
+ vals.emplace_back(TFOState::ENABLED);
+#endif
+ return vals;
+}
+
+INSTANTIATE_TEST_CASE_P(
+ ConnectTests,
+ AsyncSocketConnectTest,
+ ::testing::ValuesIn(getTestingValues()));
+
/**
* Test connecting to a server that isn't listening
*/
evb.loop();
- CHECK_EQ(cb.state, STATE_FAILED);
- CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN);
+ EXPECT_EQ(STATE_FAILED, cb.state);
+ EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
EXPECT_LE(0, socket->getConnectTime().count());
- EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+ EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
}
/**
* Test writing immediately after connecting, without waiting for connect
* to finish.
*/
-TEST(AsyncSocketTest, ConnectAndWrite) {
+TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ if (GetParam() == TFOState::ENABLED) {
+ socket->enableTFO();
+ }
+
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
/**
* Test connecting using a nullptr connect callback.
*/
-TEST(AsyncSocketTest, ConnectNullCallback) {
+TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ if (GetParam() == TFOState::ENABLED) {
+ socket->enableTFO();
+ }
+
socket->connect(nullptr, server.getAddress(), 30);
// write some data, just so we have some way of verifing
*
* This exercises the STATE_CONNECTING_CLOSING code.
*/
-TEST(AsyncSocketTest, ConnectWriteAndClose) {
+TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ if (GetParam() == TFOState::ENABLED) {
+ socket->enableTFO();
+ }
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
/**
* Test installing a read callback immediately, before connect() finishes.
*/
-TEST(AsyncSocketTest, ConnectAndRead) {
+TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ if (GetParam() == TFOState::ENABLED) {
+ socket->enableTFO();
+ }
+
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
ReadCallback rcb;
socket->setReadCB(&rcb);
+ if (GetParam() == TFOState::ENABLED) {
+ // Trigger a connection
+ socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
+ }
+
// Even though we haven't looped yet, we should be able to accept
// the connection and send data to it.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
- CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length, sizeof(buf));
CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
* Test both writing and installing a read callback immediately,
* before connect() finishes.
*/
-TEST(AsyncSocketTest, ConnectWriteAndRead) {
+TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ if (GetParam() == TFOState::ENABLED) {
+ socket->enableTFO();
+ }
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
}
+
+#if FOLLY_ALLOW_TFO
+TEST(AsyncSocketTest, ConnectTFO) {
+ // Start listening on a local port
+ TestServer server(true);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->enableTFO();
+ ConnCallback cb;
+ socket->connect(&cb, server.getAddress(), 30);
+
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+
+ std::array<uint8_t, 3> readBuf;
+ auto sendBuf = IOBuf::copyBuffer("hey");
+
+ std::thread t([&] {
+ auto acceptedSocket = server.accept();
+ acceptedSocket->write(buf.data(), buf.size());
+ acceptedSocket->flush();
+ acceptedSocket->readAll(readBuf.data(), readBuf.size());
+ acceptedSocket->close();
+ });
+
+ evb.loop();
+
+ CHECK_EQ(cb.state, STATE_SUCCEEDED);
+ EXPECT_LE(0, socket->getConnectTime().count());
+ EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+ EXPECT_TRUE(socket->getTFOAttempted());
+
+ // Should trigger the connect
+ WriteCallback write;
+ ReadCallback rcb;
+ socket->writeChain(&write, sendBuf->clone());
+ socket->setReadCB(&rcb);
+ evb.loop();
+
+ t.join();
+
+ EXPECT_EQ(STATE_SUCCEEDED, write.state);
+ EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+ EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+ ASSERT_EQ(1, rcb.buffers.size());
+ ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+ EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+/**
+ * Test connecting to a server that isn't listening
+ */
+TEST(AsyncSocketTest, ConnectRefusedTFO) {
+ EventBase evb;
+
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ socket->enableTFO();
+
+ // Hopefully nothing is actually listening on this address
+ folly::SocketAddress addr("::1", 65535);
+ ConnCallback cb;
+ socket->connect(&cb, addr, 30);
+
+ evb.loop();
+
+ WriteCallback write1;
+ // Trigger the connect if TFO attempt is supported.
+ socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
+ evb.loop();
+ WriteCallback write2;
+ socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
+ evb.loop();
+
+ if (!socket->getTFOFinished()) {
+ EXPECT_EQ(STATE_FAILED, write1.state);
+ EXPECT_FALSE(socket->getTFOFinished());
+ } else {
+ EXPECT_EQ(STATE_SUCCEEDED, write1.state);
+ EXPECT_TRUE(socket->getTFOFinished());
+ }
+
+ EXPECT_EQ(STATE_FAILED, write2.state);
+
+ EXPECT_EQ(STATE_SUCCEEDED, cb.state);
+ EXPECT_LE(0, socket->getConnectTime().count());
+ EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
+ EXPECT_TRUE(socket->getTFOAttempted());
+}
+
+/**
+ * Test calling closeNow() immediately after connecting.
+ */
+TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
+ TestServer server(true);
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // write()
+ std::array<char, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+
+ // close()
+ socket->closeNow();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+ ASSERT_TRUE(socket->isClosedBySelf());
+ ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+/**
+ * Test calling close() immediately after connect()
+ */
+TEST(AsyncSocketTest, ConnectAndCloseTFO) {
+ TestServer server(true);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ socket->close();
+
+ // Loop, although there shouldn't be anything to do.
+ evb.loop();
+
+ // Make sure the connection was aborted
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+ ASSERT_TRUE(socket->isClosedBySelf());
+ ASSERT_FALSE(socket->isClosedByPeer());
+}
+
+class MockAsyncTFOSocket : public AsyncSocket {
+ public:
+ using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;
+
+ explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
+
+ MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+TEST(AsyncSocketTest, TestTFOUnsupported) {
+ TestServer server(true);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
+ WriteCallback write;
+ auto sendBuf = IOBuf::copyBuffer("hey");
+ socket->writeChain(&write, sendBuf->clone());
+ EXPECT_EQ(STATE_WAITING, write.state);
+
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+
+ std::array<uint8_t, 3> readBuf;
+
+ std::thread t([&] {
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+ acceptedSocket->write(buf.data(), buf.size());
+ acceptedSocket->flush();
+ acceptedSocket->readAll(readBuf.data(), readBuf.size());
+ acceptedSocket->close();
+ });
+
+ evb.loop();
+
+ t.join();
+ EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+ EXPECT_EQ(STATE_SUCCEEDED, write.state);
+
+ EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+ EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+ ASSERT_EQ(1, rcb.buffers.size());
+ ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+ EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
+ // Try connecting to server that won't respond.
+ //
+ // This depends somewhat on the network where this test is run.
+ // Hopefully this IP will be routable but unresponsive.
+ // (Alternatively, we could try listening on a local raw socket, but that
+ // normally requires root privileges.)
+ auto host = SocketAddressTestHelper::isIPv6Enabled()
+ ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
+ : SocketAddressTestHelper::isIPv4Enabled()
+ ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
+ : nullptr;
+ SocketAddress addr(host, 65535);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ // Set a very small timeout
+ socket->connect(&ccb, addr, 1);
+ EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
+ WriteCallback write;
+ socket->writeChain(&write, IOBuf::copyBuffer("hey"));
+
+ evb.loop();
+
+ EXPECT_EQ(STATE_FAILED, write.state);
+}
+
+TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
+ TestServer server(true);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+ CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+ sockaddr_storage addr;
+ auto len = server.getAddress().getAddress(&addr);
+ return connect(fd, (const struct sockaddr*)&addr, len);
+ }));
+ WriteCallback write;
+ auto sendBuf = IOBuf::copyBuffer("hey");
+ socket->writeChain(&write, sendBuf->clone());
+ EXPECT_EQ(STATE_WAITING, write.state);
+
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+
+ std::array<uint8_t, 3> readBuf;
+
+ std::thread t([&] {
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+ acceptedSocket->write(buf.data(), buf.size());
+ acceptedSocket->flush();
+ acceptedSocket->readAll(readBuf.data(), readBuf.size());
+ acceptedSocket->close();
+ });
+
+ evb.loop();
+
+ t.join();
+ EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+
+ EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+ EXPECT_EQ(STATE_SUCCEEDED, write.state);
+
+ EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+ ASSERT_EQ(1, rcb.buffers.size());
+ ASSERT_EQ(buf.size(), rcb.buffers[0].length);
+ EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
+ // Try connecting to server that won't respond.
+ //
+ // This depends somewhat on the network where this test is run.
+ // Hopefully this IP will be routable but unresponsive.
+ // (Alternatively, we could try listening on a local raw socket, but that
+ // normally requires root privileges.)
+ auto host = SocketAddressTestHelper::isIPv6Enabled()
+ ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
+ : SocketAddressTestHelper::isIPv4Enabled()
+ ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
+ : nullptr;
+ SocketAddress addr(host, 65535);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ // Set a very small timeout
+ socket->connect(&ccb, addr, 1);
+ EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+
+ ReadCallback rcb;
+ socket->setReadCB(&rcb);
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+ sockaddr_storage addr2;
+ auto len = addr.getAddress(&addr2);
+ return connect(fd, (const struct sockaddr*)&addr2, len);
+ }));
+ WriteCallback write;
+ socket->writeChain(&write, IOBuf::copyBuffer("hey"));
+
+ evb.loop();
+
+ EXPECT_EQ(STATE_FAILED, write.state);
+}
+
+TEST(AsyncSocketTest, TestTFOEagain) {
+ TestServer server(true);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+ socket->enableTFO();
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
+ WriteCallback write;
+ socket->writeChain(&write, IOBuf::copyBuffer("hey"));
+
+ evb.loop();
+
+ EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
+ EXPECT_EQ(STATE_FAILED, write.state);
+}
+
+// Sending a large amount of data in the first write which will
+// definitely not fit into MSS.
+TEST(AsyncSocketTest, ConnectTFOWithBigData) {
+ // Start listening on a local port
+ TestServer server(true);
+
+ // Connect using a AsyncSocket
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+ socket->enableTFO();
+ ConnCallback cb;
+ socket->connect(&cb, server.getAddress(), 30);
+
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+
+ constexpr size_t len = 10 * 1024;
+ auto sendBuf = IOBuf::create(len);
+ sendBuf->append(len);
+ std::array<uint8_t, len> readBuf;
+
+ std::thread t([&] {
+ auto acceptedSocket = server.accept();
+ acceptedSocket->write(buf.data(), buf.size());
+ acceptedSocket->flush();
+ acceptedSocket->readAll(readBuf.data(), readBuf.size());
+ acceptedSocket->close();
+ });
+
+ evb.loop();
+
+ CHECK_EQ(cb.state, STATE_SUCCEEDED);
+ EXPECT_LE(0, socket->getConnectTime().count());
+ EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+ EXPECT_TRUE(socket->getTFOAttempted());
+
+ // Should trigger the connect
+ WriteCallback write;
+ ReadCallback rcb;
+ socket->writeChain(&write, sendBuf->clone());
+ socket->setReadCB(&rcb);
+ evb.loop();
+
+ t.join();
+
+ EXPECT_EQ(STATE_SUCCEEDED, write.state);
+ EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+ EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+ ASSERT_EQ(1, rcb.buffers.size());
+ ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+ EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
+#endif