From b669462b65cacda010d6dca11bc56f9aee768ebc Mon Sep 17 00:00:00 2001
From: Neel Goyal <ngoyal@fb.com>
Date: Tue, 1 Aug 2017 12:18:25 -0700
Subject: [PATCH] Add TLS 1.2+ version for contexts

Summary: Add an SSL Version that specifies only TLS 1.2 and up.  This prevents any client with less than TLS 1.2 from connecting.

Reviewed By: knekritz

Differential Revision: D5537423

fbshipit-source-id: 131f5b124af379eaa2b443052be9b43290c41820
---
 folly/io/async/SSLContext.cpp               |  4 +
 folly/io/async/SSLContext.h                 |  8 +-
 folly/io/async/test/AsyncSSLSocketTest2.cpp | 87 +++++++++++++++++++++
 folly/io/async/test/TestSSLServer.cpp       |  5 ++
 folly/io/async/test/TestSSLServer.h         |  2 +
 5 files changed, 102 insertions(+), 4 deletions(-)

diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp
index 3d440a8b..45936d0c 100644
--- a/folly/io/async/SSLContext.cpp
+++ b/folly/io/async/SSLContext.cpp
@@ -49,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) {
     case SSLv3:
       opt = SSL_OP_NO_SSLv2;
       break;
+    case TLSv1_2:
+      opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
+          SSL_OP_NO_TLSv1_1;
+      break;
     default:
       // do nothing
       break;
diff --git a/folly/io/async/SSLContext.h b/folly/io/async/SSLContext.h
index ded583f5..d556806f 100644
--- a/folly/io/async/SSLContext.h
+++ b/folly/io/async/SSLContext.h
@@ -68,11 +68,11 @@ class PasswordCollector {
  */
 class SSLContext {
  public:
-
   enum SSLVersion {
-     SSLv2,
-     SSLv3,
-     TLSv1
+    SSLv2,
+    SSLv3,
+    TLSv1, // support TLS 1.0+
+    TLSv1_2, // support for only TLS 1.2+
   };
 
   /**
diff --git a/folly/io/async/test/AsyncSSLSocketTest2.cpp b/folly/io/async/test/AsyncSSLSocketTest2.cpp
index eb5490c9..ae9ef53f 100644
--- a/folly/io/async/test/AsyncSSLSocketTest2.cpp
+++ b/folly/io/async/test/AsyncSSLSocketTest2.cpp
@@ -190,6 +190,93 @@ TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
   EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
 }
 
+class ConnectClient : public AsyncSocket::ConnectCallback {
+ public:
+  ConnectClient() = default;
+
+  Future<bool> getFuture() {
+    return promise_.getFuture();
+  }
+
+  void connect(const folly::SocketAddress& addr) {
+    t1_.getEventBase()->runInEventBaseThread([&] {
+      socket_ = t1_.createSocket();
+      socket_->connect(this, addr);
+    });
+  }
+
+  void connectSuccess() noexcept override {
+    promise_.setValue(true);
+    socket_.reset();
+  }
+
+  void connectErr(const AsyncSocketException& /* ex */) noexcept override {
+    promise_.setValue(false);
+    socket_.reset();
+  }
+
+  void setCtx(std::shared_ptr<SSLContext> ctx) {
+    t1_.ctx_ = ctx;
+  }
+
+ private:
+  EvbAndContext t1_;
+  // promise to fulfill when done with a value of true if connect succeeded
+  folly::Promise<bool> promise_;
+  std::shared_ptr<AsyncSSLSocket> socket_;
+};
+
+class NoopReadCallback : public ReadCallbackBase {
+ public:
+  NoopReadCallback() : ReadCallbackBase(nullptr) {
+    state = STATE_SUCCEEDED;
+  }
+
+  void getReadBuffer(void** buf, size_t* lenReturn) override {
+    *buf = &buffer_;
+    *lenReturn = 1;
+  }
+  void readDataAvailable(size_t) noexcept override {}
+
+  uint8_t buffer_{0};
+};
+
+TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
+  // Start listening on a local port
+  NoopReadCallback readCallback;
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
+  TestSSLServer server(&acceptCallback, ctx);
+  server.loadTestCerts();
+
+  // create a default client
+  auto c1 = std::make_unique<ConnectClient>();
+  auto f1 = c1->getFuture();
+  c1->connect(server.getAddress());
+  EXPECT_TRUE(f1.within(std::chrono::seconds(3)).get());
+}
+
+TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
+  // Start listening on a local port
+  NoopReadCallback readCallback;
+  HandshakeCallback handshakeCallback(
+      &readCallback, HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
+  auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
+  TestSSLServer server(&acceptCallback, ctx);
+  server.loadTestCerts();
+
+  // create a client that doesn't speak TLS 1.2
+  auto c2 = std::make_unique<ConnectClient>();
+  auto clientCtx = std::make_shared<SSLContext>();
+  clientCtx->setOptions(SSL_OP_NO_TLSv1_2);
+  c2->setCtx(clientCtx);
+  auto f2 = c2->getFuture();
+  c2->connect(server.getAddress());
+  EXPECT_FALSE(f2.within(std::chrono::seconds(3)).get());
+}
+
 } // namespace folly
 
 int main(int argc, char *argv[]) {
diff --git a/folly/io/async/test/TestSSLServer.cpp b/folly/io/async/test/TestSSLServer.cpp
index bc127db4..46d6743d 100644
--- a/folly/io/async/test/TestSSLServer.cpp
+++ b/folly/io/async/test/TestSSLServer.cpp
@@ -40,6 +40,11 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
   init(enableTFO);
 }
 
+void TestSSLServer::loadTestCerts() {
+  ctx_->loadCertificate(kTestCert);
+  ctx_->loadPrivateKey(kTestKey);
+}
+
 TestSSLServer::TestSSLServer(
     SSLServerAcceptCallbackBase* acb,
     std::shared_ptr<SSLContext> ctx,
diff --git a/folly/io/async/test/TestSSLServer.h b/folly/io/async/test/TestSSLServer.h
index a25bbb89..506b3d13 100644
--- a/folly/io/async/test/TestSSLServer.h
+++ b/folly/io/async/test/TestSSLServer.h
@@ -111,6 +111,8 @@ class TestSSLServer {
     return evb_;
   }
 
+  void loadTestCerts();
+
   const SocketAddress& getAddress() const {
     return address_;
   }
-- 
2.34.1