Add SO_ZEROCOPY support
[folly.git] / folly / io / async / AsyncSocket.cpp
index 7f8c5f13614a73f76c8a57058486a26a5bbefb69..abbec9b995ff8854789b7927f768832d1d9e2b18 100644 (file)
@@ -104,9 +104,28 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     if (getNext() != nullptr) {
       writeFlags |= WriteFlags::CORK;
     }
+
+    socket_->adjustZeroCopyFlags(getOps(), getOpCount(), writeFlags);
+
     auto writeResult = socket_->performWrite(
         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
     bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+    if (bytesWritten_) {
+      if (socket_->isZeroCopyRequest(writeFlags)) {
+        if (isComplete()) {
+          socket_->addZeroCopyBuff(std::move(ioBuf_));
+        } else {
+          socket_->addZeroCopyBuff(ioBuf_.get());
+        }
+      } else {
+        // this happens if at least one of the prev requests were sent
+        // with zero copy but not the last one
+        if (isComplete() && socket_->getZeroCopy() &&
+            socket_->containsZeroCopyBuff(ioBuf_.get())) {
+          socket_->setZeroCopyBuff(std::move(ioBuf_));
+        }
+      }
+    }
     return writeResult;
   }
 
@@ -119,11 +138,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     opIndex_ += opsWritten_;
     assert(opIndex_ < opCount_);
 
-    // If we've finished writing any IOBufs, release them
-    if (ioBuf_) {
-      for (uint32_t i = opsWritten_; i != 0; --i) {
-        assert(ioBuf_);
-        ioBuf_ = ioBuf_->pop();
+    if (!socket_->isZeroCopyRequest(flags_)) {
+      // If we've finished writing any IOBufs, release them
+      if (ioBuf_) {
+        for (uint32_t i = opsWritten_; i != 0; --i) {
+          assert(ioBuf_);
+          ioBuf_ = ioBuf_->pop();
+        }
       }
     }
 
@@ -185,8 +206,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
   struct iovec writeOps_[];     ///< write operation(s) list
 };
 
-int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
-                                                                      noexcept {
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
+    folly::WriteFlags flags,
+    bool zeroCopyEnabled) noexcept {
   int msg_flags = MSG_DONTWAIT;
 
 #ifdef MSG_NOSIGNAL // Linux-only
@@ -205,6 +227,10 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
     msg_flags |= MSG_EOR;
   }
 
+  if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
+    msg_flags |= MSG_ZEROCOPY;
+  }
+
   return msg_flags;
 }
 
@@ -433,8 +459,11 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // By default, turn on TCP_NODELAY
     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
     // setNoDelay() will log an error message if it fails.
+    // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
+    // is not created
     if (address.getFamily() != AF_UNIX) {
       (void)setNoDelay(true);
+      setZeroCopy(zeroCopyVal_);
     }
 
     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
@@ -772,6 +801,156 @@ AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
   return readCallback_;
 }
 
+bool AsyncSocket::setZeroCopy(bool enable) {
+  if (msgErrQueueSupported) {
+    zeroCopyVal_ = enable;
+
+    if (fd_ < 0) {
+      return false;
+    }
+
+    int val = enable ? 1 : 0;
+    int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
+
+    // if enable == false, set zeroCopyEnabled_ = false regardless
+    // if SO_ZEROCOPY is set or not
+    if (!enable) {
+      zeroCopyEnabled_ = enable;
+      return true;
+    }
+
+    /* if the setsockopt failed, try to see if the socket inherited the flag
+     * since we cannot set SO_ZEROCOPY on a socket s = accept
+     */
+    if (ret) {
+      val = 0;
+      socklen_t optlen = sizeof(val);
+      ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
+
+      if (!ret) {
+        enable = val ? true : false;
+      }
+    }
+
+    if (!ret) {
+      zeroCopyEnabled_ = enable;
+
+      return true;
+    }
+  }
+
+  return false;
+}
+
+void AsyncSocket::setZeroCopyWriteChainThreshold(size_t threshold) {
+  zeroCopyWriteChainThreshold_ = threshold;
+}
+
+bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
+  return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
+}
+
+void AsyncSocket::adjustZeroCopyFlags(
+    folly::IOBuf* buf,
+    folly::WriteFlags& flags) {
+  if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_ && buf) {
+    if (buf->computeChainDataLength() >= zeroCopyWriteChainThreshold_) {
+      flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+    } else {
+      flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+    }
+  }
+}
+
+void AsyncSocket::adjustZeroCopyFlags(
+    const iovec* vec,
+    uint32_t count,
+    folly::WriteFlags& flags) {
+  if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_) {
+    count = std::min<uint32_t>(count, kIovMax);
+    size_t sum = 0;
+    for (uint32_t i = 0; i < count; ++i) {
+      const iovec* v = vec + i;
+      sum += v->iov_len;
+    }
+
+    if (sum >= zeroCopyWriteChainThreshold_) {
+      flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+    } else {
+      flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+    }
+  }
+}
+
+void AsyncSocket::addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+  uint32_t id = getNextZeroCopyBuffId();
+  folly::IOBuf* ptr = buf.get();
+
+  idZeroCopyBufPtrMap_[id] = ptr;
+  auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+  p.first++;
+  CHECK(p.second.get() == nullptr);
+  p.second = std::move(buf);
+}
+
+void AsyncSocket::addZeroCopyBuff(folly::IOBuf* ptr) {
+  uint32_t id = getNextZeroCopyBuffId();
+  idZeroCopyBufPtrMap_[id] = ptr;
+
+  idZeroCopyBufPtrToBufMap_[ptr].first++;
+}
+
+void AsyncSocket::releaseZeroCopyBuff(uint32_t id) {
+  auto iter = idZeroCopyBufPtrMap_.find(id);
+  CHECK(iter != idZeroCopyBufPtrMap_.end());
+  auto ptr = iter->second;
+  auto iter1 = idZeroCopyBufPtrToBufMap_.find(ptr);
+  CHECK(iter1 != idZeroCopyBufPtrToBufMap_.end());
+  if (0 == --iter1->second.first) {
+    idZeroCopyBufPtrToBufMap_.erase(iter1);
+  }
+}
+
+void AsyncSocket::setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+  folly::IOBuf* ptr = buf.get();
+  auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+  CHECK(p.second.get() == nullptr);
+
+  p.second = std::move(buf);
+}
+
+bool AsyncSocket::containsZeroCopyBuff(folly::IOBuf* ptr) {
+  return (
+      idZeroCopyBufPtrToBufMap_.find(ptr) != idZeroCopyBufPtrToBufMap_.end());
+}
+
+bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
+#ifdef MSG_ERRQUEUE
+  if (zeroCopyEnabled_ &&
+      ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+       (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) {
+    const struct sock_extended_err* serr =
+        reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+    return (
+        (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
+  }
+#endif
+  return false;
+}
+
+void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
+#ifdef MSG_ERRQUEUE
+  const struct sock_extended_err* serr =
+      reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+  uint32_t hi = serr->ee_data;
+  uint32_t lo = serr->ee_info;
+
+  for (uint32_t i = lo; i <= hi; i++) {
+    releaseZeroCopyBuff(i);
+  }
+#endif
+}
+
 void AsyncSocket::write(WriteCallback* callback,
                          const void* buf, size_t bytes, WriteFlags flags) {
   iovec op;
@@ -789,6 +968,8 @@ void AsyncSocket::writev(WriteCallback* callback,
 
 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
                               WriteFlags flags) {
+  adjustZeroCopyFlags(buf.get(), flags);
+
   constexpr size_t kSmallSizeMax = 64;
   size_t count = buf->countChainElements();
   if (count <= kSmallSizeMax) {
@@ -860,6 +1041,10 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
             errnoCopy);
         return failWrite(__func__, callback, 0, ex);
       } else if (countWritten == count) {
+        // done, add the whole buffer
+        if (isZeroCopyRequest(flags)) {
+          addZeroCopyBuff(std::move(ioBuf));
+        }
         // We successfully wrote everything.
         // Invoke the callback and return.
         if (callback) {
@@ -867,6 +1052,10 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         }
         return;
       } else { // continue writing the next writeReq
+        // add just the ptr
+        if (isZeroCopyRequest(flags)) {
+          addZeroCopyBuff(ioBuf.get());
+        }
         if (bufferCallback_) {
           bufferCallback_->onEgressBuffered();
         }
@@ -1545,7 +1734,8 @@ void AsyncSocket::handleErrMessages() noexcept {
   // supporting per-socket error queues.
   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
           << ", state=" << state_;
-  if (errMessageCallback_ == nullptr) {
+  if (errMessageCallback_ == nullptr &&
+      (!zeroCopyEnabled_ || idZeroCopyBufPtrMap_.empty())) {
     VLOG(7) << "AsyncSocket::handleErrMessages(): "
             << "no callback installed - exiting.";
     return;
@@ -1587,11 +1777,15 @@ void AsyncSocket::handleErrMessages() noexcept {
     }
 
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
-         cmsg != nullptr &&
-           cmsg->cmsg_len != 0 &&
-           errMessageCallback_ != nullptr;
+         cmsg != nullptr && cmsg->cmsg_len != 0;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
-      errMessageCallback_->errMessage(*cmsg);
+      if (isZeroCopyMsg(*cmsg)) {
+        processZeroCopyMsg(*cmsg);
+      } else {
+        if (errMessageCallback_) {
+          errMessageCallback_->errMessage(*cmsg);
+        }
+      }
     }
   }
 #endif //MSG_ERRQUEUE
@@ -2127,7 +2321,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
   } else {
     msg.msg_control = nullptr;
   }
-  int msg_flags = sendMsgParamCallback_->getFlags(flags);
+  int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
 
   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
   auto totalWritten = writeResult.writeReturn;