Summary:
Adds a StartTLS mode to AsyncSSLSocket. Previously I could only find anyone doing something like this by using AsyncSocket, calling detachFd, then creating a new AsyncSSLSocket, and calling sslConn/sslAccept.
That had a couple downsides: 1) All pointers to the previous AsyncSocket become invalid and similarly 2) have to be super careful reads/writes happen on the correct socket, are flushed before changing socket types, etc.
This makes it super easy to just use the same AsyncSSLSocket for everything:
a) Create AsyncSSLSocket in StartTLS mode
b) send/recv anything
c) Call sslAccept/sslConn. Existing writes are still flushed in the correct order, any additional writes are buffered until handshake completes
d) Start receiving encrypted data.
I made it a new mode (vs. the default), since it seems bad to unintentionally send unencrypted data.
Use case is easy secure thrift upgrade (similar to how current kerberos does it)
Test Plan: New unittest
Reviewed By: afrind@fb.com
Subscribers: doug, ssl-diffs@, folly-diffs@, yfeldblum, chalfant, haijunz, andrewcox, alandau, alikhtarov, jsedgwick, simpkins
FB internal diff:
D2120114
Signature: t1:
2120114:
1433798448:
caeddc8feb6cc10fb34200ba97ea323bcaf09f7a
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
- EventBase* evb) :
+ EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb),
ctx_(ctx),
handshakeTimeout_(this, evb) {
init();
+ if (deferSecurityNegotiation) {
+ sslState_ = STATE_UNENCRYPTED;
+ }
}
/**
* Create a server/client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
- EventBase* evb, int fd, bool server) :
+ EventBase* evb, int fd, bool server,
+ bool deferSecurityNegotiation) :
AsyncSocket(evb, fd),
server_(server),
ctx_(ctx),
SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
AsyncSSLSocket::sslInfoCallback);
}
+ if (deferSecurityNegotiation) {
+ sslState_ = STATE_UNENCRYPTED;
+ }
}
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb,
- const std::string& serverName) :
- AsyncSSLSocket(ctx, evb) {
+ const std::string& serverName,
+ bool deferSecurityNegotiation) :
+ AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
tlsextHostname_ = serverName;
}
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd,
- const std::string& serverName) :
- AsyncSSLSocket(ctx, evb, fd, false) {
+ const std::string& serverName,
+ bool deferSecurityNegotiation) :
+ AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
tlsextHostname_ = serverName;
}
#endif
bool AsyncSSLSocket::good() const {
return (AsyncSocket::good() &&
(sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
- sslState_ == STATE_ESTABLISHED));
+ sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
}
// The TAsyncTransport definition of 'good' states that the transport is
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
- if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+ if (!server_ || (sslState_ != STATE_UNINIT &&
+ sslState_ != STATE_UNENCRYPTED) ||
+ handshakeCallback_ != nullptr) {
return invalidState(callback);
}
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
- if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+ if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
+ STATE_UNENCRYPTED) ||
+ handshakeCallback_ != nullptr) {
return invalidState(callback);
}
ssize_t
AsyncSSLSocket::performRead(void* buf, size_t buflen) {
+ if (sslState_ == STATE_UNENCRYPTED) {
+ return AsyncSocket::performRead(buf, buflen);
+ }
+
errno = 0;
ssize_t bytes = SSL_read(ssl_, buf, buflen);
if (server_ && renegotiateAttempted_) {
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
+ if (sslState_ == STATE_UNENCRYPTED) {
+ return AsyncSocket::performWrite(
+ vec, count, flags, countWritten, partialWritten);
+ }
if (sslState_ != STATE_ESTABLISHED) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
- EventBase* evb);
+ EventBase* evb, bool deferSecurityNegotiation = false);
/**
* Create a server/client AsyncSSLSocket from an already connected
* @param evb EventBase that will manage this socket.
* @param fd File descriptor to take over (should be a connected socket).
* @param server Is socket in server mode?
+ * @param deferSecurityNegotiation
+ * unencrypted data can be sent before sslConn/Accept
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb, int fd, bool server = true);
+ EventBase* evb, int fd,
+ bool server = true, bool deferSecurityNegotiation = false);
/**
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb, int fd, bool server=true) {
+ EventBase* evb, int fd, bool server=true,
+ bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
- new AsyncSSLSocket(ctx, evb, fd, server),
+ new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation),
Destructor());
}
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb) {
+ EventBase* evb, bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
- new AsyncSSLSocket(ctx, evb),
+ new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation),
Destructor());
}
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb,
- const std::string& serverName);
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false);
/**
* Create a client AsyncSSLSocket from an already connected
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
int fd,
- const std::string& serverName);
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false);
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
- const std::string& serverName) {
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
- new AsyncSSLSocket(ctx, evb, serverName),
+ new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
Destructor());
}
#endif
enum SSLStateEnum {
STATE_UNINIT,
+ STATE_UNENCRYPTED,
STATE_ACCEPTING,
STATE_CACHE_LOOKUP,
STATE_RSA_ASYNC_PENDING,
socket->setMinWriteSize(50000);
EXPECT_EQ(50000, socket->getMinWriteSize());
}
+
+class ReadCallbackTerminator : public ReadCallback {
+ public:
+ ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
+ : ReadCallback(wcb)
+ , base_(base) {}
+
+ // Do not write data back, terminate the loop.
+ void readDataAvailable(size_t len) noexcept override {
+ std::cerr << "readDataAvailable, len " << len << std::endl;
+
+ currentBuffer.length = len;
+
+ buffers.push_back(currentBuffer);
+ currentBuffer.reset();
+ state = STATE_SUCCEEDED;
+
+ socket_->setReadCB(nullptr);
+ base_->terminateLoopSoon();
+ }
+ private:
+ EventBase* base_;
+};
+
+
+/**
+ * Test a full unencrypted codepath
+ */
+TEST(AsyncSSLSocketTest, UnencryptedTest) {
+ EventBase base;
+
+ auto clientCtx = std::make_shared<folly::SSLContext>();
+ auto serverCtx = std::make_shared<folly::SSLContext>();
+ int fds[2];
+ getfds(fds);
+ getctx(clientCtx, serverCtx);
+ auto client = AsyncSSLSocket::newSocket(
+ clientCtx, &base, fds[0], false, true);
+ auto server = AsyncSSLSocket::newSocket(
+ serverCtx, &base, fds[1], true, true);
+
+ ReadCallbackTerminator readCallback(&base, nullptr);
+ server->setReadCB(&readCallback);
+ readCallback.setSocket(server);
+
+ uint8_t buf[128];
+ memset(buf, 'a', sizeof(buf));
+ client->write(nullptr, buf, sizeof(buf));
+
+ // Check that bytes are unencrypted
+ char c;
+ EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
+ EXPECT_EQ('a', c);
+
+ EventBaseAborter eba(&base, 3000);
+ base.loop();
+
+ EXPECT_EQ(1, readCallback.buffers.size());
+ EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
+
+ server->setReadCB(&readCallback);
+
+ // Unencrypted
+ server->sslAccept(nullptr);
+ client->sslConn(nullptr);
+
+ // Do NOT wait for handshake, writing should be queued and happen after
+
+ client->write(nullptr, buf, sizeof(buf));
+
+ // Check that bytes are *not* unencrypted
+ char c2;
+ EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
+ EXPECT_NE('a', c2);
+
+
+ base.loop();
+
+ EXPECT_EQ(2, readCallback.buffers.size());
+ EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
}
+} // namespace
+
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////