EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb),
ctx_(ctx),
- handshakeTimeout_(this, evb) {
+ handshakeTimeout_(this, evb),
+ connectionTimeout_(this, evb) {
init();
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
AsyncSocket(evb, fd),
server_(server),
ctx_(ctx),
- handshakeTimeout_(this, evb) {
+ handshakeTimeout_(this, evb),
+ connectionTimeout_(this, evb) {
init();
if (server) {
SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
// We are expecting a callback in restartSSLAccept. The cache lookup
// and rsa-call necessarily have pointers to this ssl socket, so delay
// the cleanup until he calls us back.
+ } else if (state_ == StateEnum::CONNECTING) {
+ assert(sslState_ == STATE_CONNECTING);
+ DestructorGuard dg(this);
+ AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
+ "Fallback connect timed out during TFO");
+ failHandshake(__func__, ex);
} else {
assert(state_ == StateEnum::ESTABLISHED &&
(sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
AsyncSocket::handleInitialReadWrite();
}
+void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
+ connectionTimeout_.cancelTimeout();
+ AsyncSocket::invokeConnectErr(ex);
+}
+
void AsyncSSLSocket::invokeConnectSuccess() {
+ connectionTimeout_.cancelTimeout();
if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
// If we failed TFO, we'd fall back to trying to connect the socket,
// to setup things like timeouts.
startSSLConnect();
}
+ // still invoke the base class since it re-sets the connect time.
AsyncSocket::invokeConnectSuccess();
}
+void AsyncSSLSocket::scheduleConnectTimeout() {
+ if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+ // We fell back from TFO, and need to set the timeouts.
+ // We will not have a connect callback in this case, thus if the timer
+ // expires we would have no-one to notify.
+ // Thus we should reset even the connect timers to point to the handshake
+ // timeouts.
+ assert(connectCallback_ == nullptr);
+ // We use a different connect timeout here than the handshake timeout, so
+ // that we can disambiguate the 2 timers.
+ int timeout = connectTimeout_.count();
+ if (timeout > 0) {
+ if (!connectionTimeout_.scheduleTimeout(timeout)) {
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to schedule AsyncSSLSocket connect timeout"));
+ }
+ }
+ return;
+ }
+ AsyncSocket::scheduleConnectTimeout();
+}
+
void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
// turn on the buffer movable in openssl
AsyncSSLSocket* sslSocket_;
};
+ // Timer for if we fallback from SSL connects to TCP connects
+ class ConnectionTimeout : public AsyncTimeout {
+ public:
+ ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
+ : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
+
+ virtual void timeoutExpired() noexcept override {
+ sslSocket_->timeoutExpired();
+ }
+
+ private:
+ AsyncSSLSocket* sslSocket_;
+ };
+
/**
* Create a client AsyncSSLSocket
*/
void invokeHandshakeErr(const AsyncSocketException& ex);
void invokeHandshakeCB();
+ void invokeConnectErr(const AsyncSocketException& ex) override;
void invokeConnectSuccess() override;
+ void scheduleConnectTimeout() override;
void cacheLocalPeerAddr();
SSL* ssl_{nullptr};
SSL_SESSION *sslSession_{nullptr};
HandshakeTimeout handshakeTimeout_;
+ ConnectionTimeout connectionTimeout_;
// whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false};
if (rv < 0) {
auto errnoCopy = errno;
if (errnoCopy == EINPROGRESS) {
- scheduleConnectTimeoutAndRegisterForEvents();
+ scheduleConnectTimeout();
+ registerForConnectEvents();
} else {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN,
return rv;
}
-void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
+void AsyncSocket::scheduleConnectTimeout() {
// Connection in progress.
int timeout = connectTimeout_.count();
if (timeout > 0) {
withAddr("failed to schedule AsyncSocket connect timeout"));
}
}
+}
+void AsyncSocket::registerForConnectEvents() {
// Register for write events, so we'll
// be notified when the connection finishes/fails.
// Note that we don't register for a persistent event here.
// cookie.
state_ = StateEnum::CONNECTING;
try {
- scheduleConnectTimeoutAndRegisterForEvents();
+ scheduleConnectTimeout();
+ registerForConnectEvents();
} catch (const AsyncSocketException& ex) {
return WriteResult(
WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
int socketConnect(const struct sockaddr* addr, socklen_t len);
- void scheduleConnectTimeoutAndRegisterForEvents();
+ virtual void scheduleConnectTimeout();
+ void registerForConnectEvents();
bool updateEventRegistration();
const AsyncSocketException& ex);
void failWrite(const char* fn, const AsyncSocketException& ex);
void failAllWrites(const AsyncSocketException& ex);
- void invokeConnectErr(const AsyncSocketException& ex);
+ virtual void invokeConnectErr(const AsyncSocketException& ex);
virtual void invokeConnectSuccess();
void invalidState(ConnectCallback* callback);
void invalidState(ReadCallback* callback);
state = State::SUCCESS;
}
- virtual void connectErr(const AsyncSocketException&) noexcept override {
+ virtual void connectErr(const AsyncSocketException& ex) noexcept override {
state = State::ERROR;
+ error = ex.what();
}
enum class State { WAITING, SUCCESS, ERROR };
State state{State::WAITING};
+ std::string error;
};
template <class Cardinality>
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO();
EXPECT_THROW(
- socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+ socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
}
TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
}
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+ // Start listening on a local port
+ EmptyReadCallback readCallback;
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
+ HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
#endif
} // namespace