Add handshake and connect times
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index f4a443214cff1cc4965be1509324603aaabeda5e..e7a556c0d92b6aa5a546e954031c485bba66de0b 100644 (file)
@@ -84,8 +84,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
   int64_t startTime_;
 
  protected:
-  virtual ~AsyncSSLSocketConnector() {
-  }
+  ~AsyncSSLSocketConnector() override {}
 
  public:
   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
@@ -98,7 +97,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
   }
 
-  virtual void connectSuccess() noexcept {
+  void connectSuccess() noexcept override {
     VLOG(7) << "client socket connected";
 
     int64_t timeoutLeft = 0;
@@ -118,13 +117,13 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
     sslSocket_->sslConn(this, timeoutLeft);
   }
 
-  virtual void connectErr(const AsyncSocketException& ex) noexcept {
+  void connectErr(const AsyncSocketException& ex) noexcept override {
     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
     fail(ex);
     delete this;
   }
 
-  virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
+  void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
     VLOG(7) << "client handshake success";
     if (callback_) {
       callback_->connectSuccess();
@@ -132,8 +131,8 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
     delete this;
   }
 
-  virtual void handshakeErr(AsyncSSLSocket *socket,
-                              const AsyncSocketException& ex) noexcept {
+  void handshakeErr(AsyncSSLSocket* socket,
+                    const AsyncSocketException& ex) noexcept override {
     LOG(ERROR) << "client handshakeErr: " << ex.what();
     fail(ex);
     delete this;
@@ -327,6 +326,8 @@ void AsyncSSLSocket::init() {
   // Do this here to ensure we initialize this once before any use of
   // AsyncSSLSocket instances and not as part of library load.
   static const auto eorAwareBioMethodInitializer = initEorBioMethod();
+  (void)eorAwareBioMethodInitializer;
+
   setup_SSL_CTX(ctx_->getSSLCtx());
 }
 
@@ -355,13 +356,10 @@ void AsyncSSLSocket::closeNow() {
 
   DestructorGuard dg(this);
 
-  if (handshakeCallback_) {
-    AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
-                           "SSL connection closed locally");
-    HandshakeCB* callback = handshakeCallback_;
-    handshakeCallback_ = nullptr;
-    callback->handshakeErr(this, ex);
-  }
+  invokeHandshakeErr(
+      AsyncSocketException(
+        AsyncSocketException::END_OF_FILE,
+        "SSL connection closed locally"));
 
   if (ssl_ != nullptr) {
     SSL_free(ssl_);
@@ -467,6 +465,7 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
                          "sslAccept() called with socket in invalid state");
 
+  handshakeEndTime_ = std::chrono::steady_clock::now();
   if (callback) {
     callback->handshakeErr(this, ex);
   }
@@ -489,6 +488,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
       handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
+  handshakeStartTime_ = std::chrono::steady_clock::now();
+  // Make end time at least >= start time.
+  handshakeEndTime_ = handshakeStartTime_;
 
   sslState_ = STATE_ACCEPTING;
   handshakeCallback_ = callback;
@@ -608,18 +610,10 @@ void AsyncSSLSocket::timeoutExpired() noexcept {
   }
 }
 
-int AsyncSSLSocket::sslExDataIndex_ = -1;
-std::mutex AsyncSSLSocket::mutex_;
-
 int AsyncSSLSocket::getSSLExDataIndex() {
-  if (sslExDataIndex_ < 0) {
-    std::lock_guard<std::mutex> g(mutex_);
-    if (sslExDataIndex_ < 0) {
-      sslExDataIndex_ = SSL_get_ex_new_index(0,
-          (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
-    }
-  }
-  return sslExDataIndex_;
+  static auto index = SSL_get_ex_new_index(
+      0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
+  return index;
 }
 
 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
@@ -630,20 +624,24 @@ AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
 void AsyncSSLSocket::failHandshake(const char* fn,
                                     const AsyncSocketException& ex) {
   startFail();
-
   if (handshakeTimeout_.isScheduled()) {
     handshakeTimeout_.cancelTimeout();
   }
+  invokeHandshakeErr(ex);
+  finishFail();
+}
+
+void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
+  handshakeEndTime_ = std::chrono::steady_clock::now();
   if (handshakeCallback_ != nullptr) {
     HandshakeCB* callback = handshakeCallback_;
     handshakeCallback_ = nullptr;
     callback->handshakeErr(this, ex);
   }
-
-  finishFail();
 }
 
 void AsyncSSLSocket::invokeHandshakeCB() {
+  handshakeEndTime_ = std::chrono::steady_clock::now();
   if (handshakeTimeout_.isScheduled()) {
     handshakeTimeout_.cancelTimeout();
   }
@@ -698,6 +696,10 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
     return invalidState(callback);
   }
 
+  handshakeStartTime_ = std::chrono::steady_clock::now();
+  // Make end time at least >= start time.
+  handshakeEndTime_ = handshakeStartTime_;
+
   sslState_ = STATE_CONNECTING;
   handshakeCallback_ = callback;
 
@@ -1071,6 +1073,29 @@ AsyncSSLSocket::handleConnect() noexcept {
   AsyncSocket::handleInitialReadWrite();
 }
 
+void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+  // turn on the buffer movable in openssl
+  if (!isBufferMovable_ && callback != nullptr && callback->isBufferMovable()) {
+    SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
+    isBufferMovable_ = true;
+  }
+#endif
+
+  AsyncSocket::setReadCB(callback);
+}
+
+void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+  CHECK(readCallback_);
+  if (isBufferMovable_) {
+    *buf = nullptr;
+    *buflen = 0;
+  } else {
+    // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
+    readCallback_->getReadBuffer(buf, buflen);
+  }
+}
+
 void
 AsyncSSLSocket::handleRead() noexcept {
   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
@@ -1097,13 +1122,25 @@ AsyncSSLSocket::handleRead() noexcept {
 }
 
 ssize_t
-AsyncSSLSocket::performRead(void* buf, size_t buflen) {
+AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
+  VLOG(4) << "AsyncSSLSocket::performRead() this=" << this
+          << ", buf=" << *buf << ", buflen=" << *buflen;
+
   if (sslState_ == STATE_UNENCRYPTED) {
-    return AsyncSocket::performRead(buf, buflen);
+    return AsyncSocket::performRead(buf, buflen, offset);
   }
 
   errno = 0;
-  ssize_t bytes = SSL_read(ssl_, buf, buflen);
+  ssize_t bytes = 0;
+  if (!isBufferMovable_) {
+    bytes = SSL_read(ssl_, *buf, *buflen);
+  }
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+  else {
+    bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
+  }
+#endif
+
   if (server_ && renegotiateAttempted_) {
     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
@@ -1572,13 +1609,30 @@ AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
     if (cursor.totalLength() > 0) {
       uint16_t extensionsLength = cursor.readBE<uint16_t>();
       while (extensionsLength) {
+        TLSExtension extensionType = static_cast<TLSExtension>(
+            cursor.readBE<uint16_t>());
         sock->clientHelloInfo_->
-          clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
+          clientHelloExtensions_.push_back(extensionType);
         extensionsLength -= 2;
         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
         extensionsLength -= 2;
-        cursor.skip(extensionDataLength);
-        extensionsLength -= extensionDataLength;
+
+        if (extensionType == TLSExtension::SIGNATURE_ALGORITHMS) {
+          cursor.skip(2);
+          extensionDataLength -= 2;
+          while (extensionDataLength) {
+            HashAlgorithm hashAlg = static_cast<HashAlgorithm>(
+                cursor.readBE<uint8_t>());
+            SignatureAlgorithm sigAlg = static_cast<SignatureAlgorithm>(
+                cursor.readBE<uint8_t>());
+            extensionDataLength -= 2;
+            sock->clientHelloInfo_->
+              clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
+          }
+        } else {
+          cursor.skip(extensionDataLength);
+          extensionsLength -= extensionDataLength;
+        }
       }
     }
   } catch (std::out_of_range& e) {