From: Subodh Iyengar Date: Fri, 25 Nov 2016 05:18:22 +0000 (-0800) Subject: Fix TFO refused case X-Git-Tag: v2016.11.28.00~4 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=16dc0043e4d9ad309b6d66565511181732ff0827;p=folly.git Fix TFO refused case Summary: When TFO falls back, it's possible that the fallback can also error out. We handle this correctly in AsyncSocket, however because AsyncSSLSocket is so inter-twined with AsyncSocket, we missed the case of error as well. This changes it so that a connect error on fallback will cause a handshake error Differential Revision: D4226477 fbshipit-source-id: c6e845e4a907bfef1e6ad1b4118db47184d047e0 --- diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 8804b8d1..f8e6b471 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -1127,11 +1127,21 @@ AsyncSSLSocket::handleConnect() noexcept { void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) { connectionTimeout_.cancelTimeout(); AsyncSocket::invokeConnectErr(ex); + if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + assert(tfoAttempted_); + if (handshakeTimeout_.isScheduled()) { + handshakeTimeout_.cancelTimeout(); + } + // If we fell back to connecting state during TFO and the connection + // failed, it would be an SSL failure as well. + invokeHandshakeErr(ex); + } } void AsyncSSLSocket::invokeConnectSuccess() { connectionTimeout_.cancelTimeout(); if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + assert(tfoAttempted_); // If we failed TFO, we'd fall back to trying to connect the socket, // to setup things like timeouts. startSSLConnect(); diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 652f90f5..729185b8 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -1798,8 +1798,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { errno = EAGAIN; totalWritten = -1; } else if (errno == EOPNOTSUPP) { - VLOG(4) << "TFO not supported"; // Try falling back to connecting. + VLOG(4) << "TFO not supported"; state_ = StateEnum::CONNECTING; try { int ret = socketConnect((const sockaddr*)&addr, len); @@ -1977,12 +1977,7 @@ void AsyncSocket::startFail() { } } -void AsyncSocket::finishFail() { - assert(state_ == StateEnum::ERROR); - assert(getDestructorGuardCount() > 0); - - AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR, - withAddr("socket closing after error")); +void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) { invokeConnectErr(ex); failAllWrites(ex); @@ -1993,6 +1988,22 @@ void AsyncSocket::finishFail() { } } +void AsyncSocket::finishFail() { + assert(state_ == StateEnum::ERROR); + assert(getDestructorGuardCount() > 0); + + AsyncSocketException ex( + AsyncSocketException::INTERNAL_ERROR, + withAddr("socket closing after error")); + invokeAllErrors(ex); +} + +void AsyncSocket::finishFail(const AsyncSocketException& ex) { + assert(state_ == StateEnum::ERROR); + assert(getDestructorGuardCount() > 0); + invokeAllErrors(ex); +} + void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) { VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state=" << state_ << " host=" << addr_.describe() @@ -2010,7 +2021,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) { startFail(); invokeConnectErr(ex); - finishFail(); + finishFail(ex); } void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) { diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 3f5d715d..2fed39a0 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -877,6 +877,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { // error handling methods void startFail(); void finishFail(); + void finishFail(const AsyncSocketException& ex); + void invokeAllErrors(const AsyncSocketException& ex); void fail(const char* fn, const AsyncSocketException& ex); void failConnect(const char* fn, const AsyncSocketException& ex); void failRead(const char* fn, const AsyncSocketException& ex); diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index 65c056f6..bc7c4f48 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -1918,6 +1918,21 @@ TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) { EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out")); } +TEST(AsyncSSLSocketTest, HandshakeTFORefused) { + // Start listening on a local port + EventBase evb; + + // Hopefully nothing is listening on this address + SocketAddress addr("127.0.0.1", 65535); + auto socket = setupSocketWithFallback(&evb, addr, AtMost(1)); + ConnCallback ccb; + socket->connect(&ccb, addr, 100); + + evb.loop(); + EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); + EXPECT_THAT(ccb.error, testing::HasSubstr("refused")); +} + #endif } // namespace diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 36ccc305..afe23fa1 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -2524,7 +2524,7 @@ TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) { /** * Test connecting to a server that isn't listening */ -TEST(AsyncSocketTest, ConnectRefusedTFO) { +TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) { EventBase evb; std::shared_ptr socket = AsyncSocket::newSocket(&evb); @@ -2541,7 +2541,6 @@ TEST(AsyncSocketTest, ConnectRefusedTFO) { WriteCallback write1; // Trigger the connect if TFO attempt is supported. socket->writeChain(&write1, IOBuf::copyBuffer("hey")); - evb.loop(); WriteCallback write2; socket->writeChain(&write2, IOBuf::copyBuffer("hey")); evb.loop(); @@ -2675,6 +2674,51 @@ TEST(AsyncSocketTest, TestTFOUnsupported) { EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded()); } +TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) { + EventBase evb; + + auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb)); + socket->enableTFO(); + + // Hopefully this fails + folly::SocketAddress fakeAddr("127.0.0.1", 65535); + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(Invoke([&](int fd, struct msghdr*, int) { + sockaddr_storage addr; + auto len = fakeAddr.getAddress(&addr); + int ret = connect(fd, (const struct sockaddr*)&addr, len); + LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : " + << errno; + return ret; + })); + + // Hopefully nothing is actually listening on this address + ConnCallback cb; + socket->connect(&cb, fakeAddr, 30); + + WriteCallback write1; + // Trigger the connect if TFO attempt is supported. + socket->writeChain(&write1, IOBuf::copyBuffer("hey")); + + if (socket->getTFOFinished()) { + // This test is useless now. + return; + } + WriteCallback write2; + // Trigger the connect if TFO attempt is supported. + socket->writeChain(&write2, IOBuf::copyBuffer("hey")); + evb.loop(); + + EXPECT_EQ(STATE_FAILED, write1.state); + EXPECT_EQ(STATE_FAILED, write2.state); + EXPECT_FALSE(socket->getTFOSucceded()); + + EXPECT_EQ(STATE_SUCCEEDED, cb.state); + EXPECT_LE(0, socket->getConnectTime().count()); + EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout()); + EXPECT_TRUE(socket->getTFOAttempted()); +} + TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) { // Try connecting to server that won't respond. // diff --git a/folly/io/async/test/SocketClient.cpp b/folly/io/async/test/SocketClient.cpp index 1a93ab95..2fbb546b 100644 --- a/folly/io/async/test/SocketClient.cpp +++ b/folly/io/async/test/SocketClient.cpp @@ -27,6 +27,9 @@ DEFINE_int32(port, 0, "port"); DEFINE_bool(tfo, false, "enable tfo"); DEFINE_string(msg, "", "Message to send"); DEFINE_bool(ssl, false, "use ssl"); +DEFINE_int32(timeout_ms, 0, "timeout"); +DEFINE_int32(sendtimeout_ms, 0, "send timeout"); +DEFINE_int32(num_writes, 1, "number of writes"); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -53,6 +56,10 @@ int main(int argc, char** argv) { #endif } + if (FLAGS_sendtimeout_ms != 0) { + socket->setSendTimeout(FLAGS_sendtimeout_ms); + } + // Keep this around auto sockAddr = socket.get(); @@ -60,10 +67,13 @@ int main(int argc, char** argv) { SocketAddress addr; addr.setFromHostPort(FLAGS_host, FLAGS_port); sock.setAddress(addr); - sock.open(); + std::chrono::milliseconds timeout(FLAGS_timeout_ms); + sock.open(timeout); LOG(INFO) << "connected to " << addr.getAddressStr(); - sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size()); + for (int32_t i = 0; i < FLAGS_num_writes; ++i) { + sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size()); + } LOG(INFO) << "TFO attempted: " << sockAddr->getTFOAttempted(); LOG(INFO) << "TFO finished: " << sockAddr->getTFOFinished();