Revise API to load cert/key in SSLContext.
authorXiangyu Bu <xbu@fb.com>
Mon, 27 Nov 2017 23:37:45 +0000 (15:37 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 27 Nov 2017 23:50:50 +0000 (15:50 -0800)
Summary:
When loading cert/key pair, order matters:
(a) Wrong key will fail to load if a cert is loaded;
(b) Wrong cert will succeed to load even if a private key is loaded.

So this diff adds:
(1) SSLContext::checkPrivateKey() -- must call for case (b).
(2) SSLContext::loadCertKeyPairFromBufferPEM() -- use this if one loads both cert and key. Guaranteed to throw if cert/key mismatch.

Reviewed By: yfeldblum

Differential Revision: D6416280

fbshipit-source-id: 8ae370883d46e9b5afb69c506c09fbf7ba82b1b9

folly/io/async/SSLContext.cpp
folly/io/async/SSLContext.h
folly/io/async/test/SSLContextTest.cpp

index 498d8dfb3dfa76fcf6a64ac2288b23a9ea42925b..7ce05f44a7a062c974ffe8a65dd17c1bc1293f40 100644 (file)
@@ -287,6 +287,26 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
   }
 }
 
+void SSLContext::loadCertKeyPairFromBufferPEM(
+    folly::StringPiece cert,
+    folly::StringPiece pkey) {
+  loadCertificateFromBufferPEM(cert);
+  loadPrivateKeyFromBufferPEM(pkey);
+}
+
+void SSLContext::loadCertKeyPairFromFiles(
+    const char* certPath,
+    const char* keyPath,
+    const char* certFormat,
+    const char* keyFormat) {
+  loadCertificate(certPath, certFormat);
+  loadPrivateKey(keyPath, keyFormat);
+}
+
+bool SSLContext::isCertKeyPairValid() const {
+  return SSL_CTX_check_private_key(ctx_) == 1;
+}
+
 void SSLContext::loadTrustedCertificates(const char* path) {
   if (path == nullptr) {
     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
index c8db033eb525e780599c920642e560e0bfba1b9e..bdd04509119d4c82b34d1d769716ea9b9a57e7b9 100644 (file)
@@ -275,6 +275,7 @@ class SSLContext {
    * @param cert  A PEM formatted certificate
    */
   virtual void loadCertificateFromBufferPEM(folly::StringPiece cert);
+
   /**
    * Load private key.
    *
@@ -288,6 +289,41 @@ class SSLContext {
    * @param pkey  A PEM formatted key
    */
   virtual void loadPrivateKeyFromBufferPEM(folly::StringPiece pkey);
+
+  /**
+   * Load cert and key from PEM buffers. Guaranteed to throw if cert and
+   * private key mismatch so no need to call isCertKeyPairValid.
+   *
+   * @param cert A PEM formatted certificate
+   * @param pkey A PEM formatted key
+   */
+  virtual void loadCertKeyPairFromBufferPEM(
+      folly::StringPiece cert,
+      folly::StringPiece pkey);
+
+  /**
+   * Load cert and key from files. Guaranteed to throw if cert and key mismatch.
+   * Equivalent to calling loadCertificate() and loadPrivateKey().
+   *
+   * @param certPath   Path to the certificate file
+   * @param keyPath   Path to the private key file
+   * @param certFormat Certificate file format
+   * @param keyFormat Private key file format
+   */
+  virtual void loadCertKeyPairFromFiles(
+      const char* certPath,
+      const char* keyPath,
+      const char* certFormat = "PEM",
+      const char* keyFormat = "PEM");
+
+  /**
+   * Call after both cert and key are loaded to check if cert matches key.
+   * Must call if private key is loaded before loading the cert.
+   * No need to call if cert is loaded first before private key.
+   * @return true if matches, or false if mismatch.
+   */
+  virtual bool isCertKeyPairValid() const;
+
   /**
    * Load trusted certificates from specified file.
    *
index 955a1fc8669589fa83be4cadbba7a27c18f6e4df..fafa433b375184a6c6a5db9fb8565af26612a72c 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include <folly/io/async/SSLContext.h>
+#include <folly/FileUtil.h>
 #include <folly/portability/GTest.h>
 #include <folly/ssl/OpenSSLPtrTypes.h>
 
@@ -48,4 +49,92 @@ TEST_F(SSLContextTest, TestSetCipherList) {
   ctx.setCipherList(ciphers);
   verifySSLCipherList(ciphers);
 }
+
+TEST_F(SSLContextTest, TestLoadCertKey) {
+  std::string certData, keyData, anotherKeyData;
+  const char* certPath = "folly/io/async/test/certs/tests-cert.pem";
+  const char* keyPath = "folly/io/async/test/certs/tests-key.pem";
+  const char* anotherKeyPath = "folly/io/async/test/certs/client_key.pem";
+  folly::readFile(certPath, certData);
+  folly::readFile(keyPath, keyData);
+  folly::readFile(anotherKeyPath, anotherKeyData);
+
+  {
+    SCOPED_TRACE("Valid cert/key pair from buffer");
+    SSLContext tmpCtx;
+    tmpCtx.loadCertificateFromBufferPEM(certData);
+    tmpCtx.loadPrivateKeyFromBufferPEM(keyData);
+    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
+  }
+
+  {
+    SCOPED_TRACE("Valid cert/key pair from files");
+    SSLContext tmpCtx;
+    tmpCtx.loadCertificate(certPath);
+    tmpCtx.loadPrivateKey(keyPath);
+    EXPECT_TRUE(tmpCtx.isCertKeyPairValid());
+  }
+
+  {
+    SCOPED_TRACE("Invalid cert/key pair from file. Load cert first");
+    SSLContext tmpCtx;
+    tmpCtx.loadCertificate(certPath);
+    EXPECT_THROW(tmpCtx.loadPrivateKey(anotherKeyPath), std::runtime_error);
+  }
+
+  {
+    SCOPED_TRACE("Invalid cert/key pair from file. Load key first");
+    SSLContext tmpCtx;
+    tmpCtx.loadPrivateKey(anotherKeyPath);
+    tmpCtx.loadCertificate(certPath);
+    EXPECT_FALSE(tmpCtx.isCertKeyPairValid());
+  }
+
+  {
+    SCOPED_TRACE("Invalid key/cert pair from buf. Load cert first");
+    SSLContext tmpCtx;
+    tmpCtx.loadCertificateFromBufferPEM(certData);
+    EXPECT_THROW(
+        tmpCtx.loadPrivateKeyFromBufferPEM(anotherKeyData), std::runtime_error);
+  }
+
+  {
+    SCOPED_TRACE("Invalid key/cert pair from buf. Load key first");
+    SSLContext tmpCtx;
+    tmpCtx.loadPrivateKeyFromBufferPEM(anotherKeyData);
+    tmpCtx.loadCertificateFromBufferPEM(certData);
+    EXPECT_FALSE(tmpCtx.isCertKeyPairValid());
+  }
+
+  {
+    SCOPED_TRACE(
+        "loadCertKeyPairFromBufferPEM() must throw when cert/key mismatch");
+    SSLContext tmpCtx;
+    EXPECT_THROW(
+        tmpCtx.loadCertKeyPairFromBufferPEM(certData, anotherKeyData),
+        std::runtime_error);
+  }
+
+  {
+    SCOPED_TRACE(
+        "loadCertKeyPairFromBufferPEM() must succeed when cert/key match");
+    SSLContext tmpCtx;
+    tmpCtx.loadCertKeyPairFromBufferPEM(certData, keyData);
+  }
+
+  {
+    SCOPED_TRACE(
+        "loadCertKeyPairFromFiles() must throw when cert/key mismatch");
+    SSLContext tmpCtx;
+    EXPECT_THROW(
+        tmpCtx.loadCertKeyPairFromFiles(certPath, anotherKeyPath),
+        std::runtime_error);
+  }
+
+  {
+    SCOPED_TRACE("loadCertKeyPairFromFiles() must succeed when cert/key match");
+    SSLContext tmpCtx;
+    tmpCtx.loadCertKeyPairFromFiles(certPath, keyPath);
+  }
+}
 } // namespace folly