if (getNext() != nullptr) {
writeFlags |= WriteFlags::CORK;
}
+
+ socket_->adjustZeroCopyFlags(getOps(), getOpCount(), writeFlags);
+
auto writeResult = socket_->performWrite(
getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+ if (bytesWritten_) {
+ if (socket_->isZeroCopyRequest(writeFlags)) {
+ if (isComplete()) {
+ socket_->addZeroCopyBuff(std::move(ioBuf_));
+ } else {
+ socket_->addZeroCopyBuff(ioBuf_.get());
+ }
+ } else {
+ // this happens if at least one of the prev requests were sent
+ // with zero copy but not the last one
+ if (isComplete() && socket_->getZeroCopy() &&
+ socket_->containsZeroCopyBuff(ioBuf_.get())) {
+ socket_->setZeroCopyBuff(std::move(ioBuf_));
+ }
+ }
+ }
return writeResult;
}
opIndex_ += opsWritten_;
assert(opIndex_ < opCount_);
- // If we've finished writing any IOBufs, release them
- if (ioBuf_) {
- for (uint32_t i = opsWritten_; i != 0; --i) {
- assert(ioBuf_);
- ioBuf_ = ioBuf_->pop();
+ if (!socket_->isZeroCopyRequest(flags_)) {
+ // If we've finished writing any IOBufs, release them
+ if (ioBuf_) {
+ for (uint32_t i = opsWritten_; i != 0; --i) {
+ assert(ioBuf_);
+ ioBuf_ = ioBuf_->pop();
+ }
}
}
struct iovec writeOps_[]; ///< write operation(s) list
};
-int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
- noexcept {
+// Translate folly::WriteFlags into sendmsg() msg_flags. The new
+// zeroCopyEnabled parameter gates MSG_ZEROCOPY so the kernel flag is only
+// requested when SO_ZEROCOPY was successfully enabled on the socket.
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
+ folly::WriteFlags flags,
+ bool zeroCopyEnabled) noexcept {
int msg_flags = MSG_DONTWAIT;
#ifdef MSG_NOSIGNAL // Linux-only
msg_flags |= MSG_EOR;
}
+ // NOTE(review): several context lines (including the matching #endif for
+ // the MSG_NOSIGNAL block above) appear elided from this diff view —
+ // verify against the full file before applying.
+ if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
+ msg_flags |= MSG_ZEROCOPY;
+ }
+
return msg_flags;
}
// init() method, since constructor forwarding isn't supported in most
// compilers yet.
void AsyncSocket::init() {
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
shutdownFlags_ = 0;
state_ = StateEnum::UNINIT;
eventFlags_ = EventHandler::NONE;
const OptionMap &options,
const folly::SocketAddress& bindAddr) noexcept {
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
addr_ = address;
// By default, turn on TCP_NODELAY
// If setNoDelay() fails, we continue anyway; this isn't a fatal error.
// setNoDelay() will log an error message if it fails.
+ // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
+ // is not created
if (address.getFamily() != AF_UNIX) {
(void)setNoDelay(true);
+ setZeroCopy(zeroCopyVal_);
}
VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
// Ignore return value, errors are ok
setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
}
+ if (noTSocks_) {
+ VLOG(4) << "Disabling TSOCKS for fd " << fd_;
+ // Ignore return value, errors are ok
+ setsockopt(fd_, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
+ }
#endif
int rv = fsp::connect(fd_, saddr, len);
if (rv < 0) {
void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
sendTimeout_ = milliseconds;
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
// If we are currently pending on write requests, immediately update
// writeTimeout_ with the new value.
}
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (callback == nullptr) {
// We should be able to reset the callback regardless of the
}
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
switch ((StateEnum)state_) {
case StateEnum::CONNECTING:
return readCallback_;
}
+// Enable/disable zero-copy (SO_ZEROCOPY) sends on this socket.
+// The requested value is cached in zeroCopyVal_ so it can be (re)applied once
+// a fd exists; returns true only when the enabled/disabled state was actually
+// established on the fd (directly, or inherited by an accept()-ed socket).
+bool AsyncSocket::setZeroCopy(bool enable) {
+ if (msgErrQueueSupported) {
+ zeroCopyVal_ = enable;
+
+ if (fd_ < 0) {
+ // No fd yet — the cached zeroCopyVal_ is applied later (see connect()).
+ return false;
+ }
+
+ int val = enable ? 1 : 0;
+ int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
+
+ // if enable == false, set zeroCopyEnabled_ = false regardless
+ // if SO_ZEROCOPY is set or not
+ if (!enable) {
+ zeroCopyEnabled_ = enable;
+ return true;
+ }
+
+ /* if the setsockopt failed, try to see if the socket inherited the flag
+ * since we cannot set SO_ZEROCOPY on a socket s = accept
+ */
+ if (ret) {
+ val = 0;
+ socklen_t optlen = sizeof(val);
+ ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
+
+ if (!ret) {
+ // Reflect the state the socket actually inherited.
+ enable = val ? true : false;
+ }
+ }
+
+ if (!ret) {
+ zeroCopyEnabled_ = enable;
+
+ return true;
+ }
+ }
+
+ // MSG_ERRQUEUE unsupported, or both setsockopt and getsockopt failed.
+ return false;
+}
+
+// Minimum total payload size (in bytes) for which writes are upgraded to
+// zero-copy; see adjustZeroCopyFlags(). A threshold of 0 disables the
+// automatic flag adjustment.
+void AsyncSocket::setZeroCopyWriteChainThreshold(size_t threshold) {
+ zeroCopyWriteChainThreshold_ = threshold;
+}
+
+// True when this write should use zero-copy: the socket has SO_ZEROCOPY
+// enabled AND the caller (or adjustZeroCopyFlags) requested it via
+// WRITE_MSG_ZEROCOPY.
+bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
+ return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
+}
+
+// Set or clear WRITE_MSG_ZEROCOPY on `flags` based on the total chain length
+// of `buf` versus zeroCopyWriteChainThreshold_. Zero-copy only pays off for
+// large payloads (the completion round-trip has fixed cost), so small writes
+// are forced back to regular copies. No-op when zero-copy or the threshold
+// is disabled, or buf is null.
+void AsyncSocket::adjustZeroCopyFlags(
+ folly::IOBuf* buf,
+ folly::WriteFlags& flags) {
+ if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_ && buf) {
+ if (buf->computeChainDataLength() >= zeroCopyWriteChainThreshold_) {
+ flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+ } else {
+ flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+ }
+ }
+}
+
+// iovec-based overload of the threshold check above: sums iov_len over at
+// most kIovMax entries (the per-sendmsg limit) and sets/clears
+// WRITE_MSG_ZEROCOPY accordingly.
+void AsyncSocket::adjustZeroCopyFlags(
+ const iovec* vec,
+ uint32_t count,
+ folly::WriteFlags& flags) {
+ if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_) {
+ // Only the first kIovMax entries can be written in one call.
+ count = std::min<uint32_t>(count, kIovMax);
+ size_t sum = 0;
+ for (uint32_t i = 0; i < count; ++i) {
+ const iovec* v = vec + i;
+ sum += v->iov_len;
+ }
+
+ if (sum >= zeroCopyWriteChainThreshold_) {
+ flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+ } else {
+ flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+ }
+ }
+}
+
+// Register a zero-copy send and take ownership of the buffer: maps a fresh
+// completion id to the chain's head pointer, bumps the chain's refcount, and
+// stores the unique_ptr so the bytes stay alive until the kernel signals
+// completion (see releaseZeroCopyBuff). The CHECK enforces that ownership of
+// a given chain is only deposited once.
+void AsyncSocket::addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+ uint32_t id = getNextZeroCopyBuffId();
+ folly::IOBuf* ptr = buf.get();
+
+ idZeroCopyBufPtrMap_[id] = ptr;
+ auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+ p.first++;
+ CHECK(p.second.get() == nullptr);
+ p.second = std::move(buf);
+}
+
+// Register a zero-copy send for a partially-written chain: records the id ->
+// ptr mapping and bumps the refcount, but ownership is NOT transferred yet —
+// the unique_ptr is deposited later via setZeroCopyBuff() once the chain
+// finishes writing.
+void AsyncSocket::addZeroCopyBuff(folly::IOBuf* ptr) {
+ uint32_t id = getNextZeroCopyBuffId();
+ idZeroCopyBufPtrMap_[id] = ptr;
+
+ idZeroCopyBufPtrToBufMap_[ptr].first++;
+}
+
+// Called when the kernel reports (via MSG_ERRQUEUE, see processZeroCopyMsg)
+// that the zero-copy send identified by `id` has completed. Drops one
+// reference on the associated IOBuf chain; when the last reference goes away
+// the owned buffer is destroyed.
+void AsyncSocket::releaseZeroCopyBuff(uint32_t id) {
+ auto iter = idZeroCopyBufPtrMap_.find(id);
+ CHECK(iter != idZeroCopyBufPtrMap_.end());
+ auto ptr = iter->second;
+ auto iter1 = idZeroCopyBufPtrToBufMap_.find(ptr);
+ CHECK(iter1 != idZeroCopyBufPtrToBufMap_.end());
+ if (0 == --iter1->second.first) {
+ idZeroCopyBufPtrToBufMap_.erase(iter1);
+ }
+
+ // Fix: also drop the id -> ptr entry. Without this the map grows without
+ // bound across the connection's lifetime and retains a dangling IOBuf*
+ // once the owning buffer above is destroyed (handleErrMessages also keys
+ // its early-exit off idZeroCopyBufPtrMap_.empty()).
+ idZeroCopyBufPtrMap_.erase(iter);
+}
+
+// Deposit ownership of a chain that was previously registered by pointer
+// (addZeroCopyBuff(folly::IOBuf*)) once its final write request completes.
+// The entry must already exist and must not yet own a buffer.
+void AsyncSocket::setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+ folly::IOBuf* ptr = buf.get();
+ auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+ CHECK(p.second.get() == nullptr);
+
+ p.second = std::move(buf);
+}
+
+// True if `ptr` heads a chain with outstanding zero-copy sends (i.e. the
+// kernel may still read from its pages).
+bool AsyncSocket::containsZeroCopyBuff(folly::IOBuf* ptr) {
+ return (
+ idZeroCopyBufPtrToBufMap_.find(ptr) != idZeroCopyBufPtrToBufMap_.end());
+}
+
+// Classify an error-queue cmsg as a zero-copy completion notification.
+// The kernel delivers completions as "fake errors" on the error queue:
+// level/type is IP_RECVERR (or IPV6_RECVERR) with ee_errno == 0 and
+// ee_origin == SO_EE_ORIGIN_ZEROCOPY, which distinguishes them from real
+// ICMP-style errors passed on to errMessageCallback_.
+bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
+#ifdef MSG_ERRQUEUE
+ if (zeroCopyEnabled_ &&
+ ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+ (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) {
+ const struct sock_extended_err* serr =
+ reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+ return (
+ (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
+ }
+#endif
+ // Without MSG_ERRQUEUE support nothing can be a zero-copy notification.
+ return false;
+}
+
+// Handle a zero-copy completion cmsg: the kernel packs an inclusive range of
+// completed send ids into the extended error — ee_info is the low id and
+// ee_data the high id — and every id in [lo, hi] is released.
+// NOTE(review): the inclusive loop would not terminate if hi == UINT32_MAX;
+// presumably ids never reach that value before the id counter wraps — confirm.
+void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
+#ifdef MSG_ERRQUEUE
+ const struct sock_extended_err* serr =
+ reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+ uint32_t hi = serr->ee_data;
+ uint32_t lo = serr->ee_info;
+
+ for (uint32_t i = lo; i <= hi; i++) {
+ releaseZeroCopyBuff(i);
+ }
+#endif
+}
+
void AsyncSocket::write(WriteCallback* callback,
const void* buf, size_t bytes, WriteFlags flags) {
iovec op;
void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
WriteFlags flags) {
+ adjustZeroCopyFlags(buf.get(), flags);
+
constexpr size_t kSmallSizeMax = 64;
size_t count = buf->countChainElements();
if (count <= kSmallSizeMax) {
<< ", state=" << state_;
DestructorGuard dg(this);
unique_ptr<IOBuf>ioBuf(std::move(buf));
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
// No new writes may be performed after the write side of the socket has
errnoCopy);
return failWrite(__func__, callback, 0, ex);
} else if (countWritten == count) {
+ // done, add the whole buffer
+ if (isZeroCopyRequest(flags)) {
+ addZeroCopyBuff(std::move(ioBuf));
+ }
// We successfully wrote everything.
// Invoke the callback and return.
if (callback) {
}
return;
} else { // continue writing the next writeReq
+ // add just the ptr
+ if (isZeroCopyRequest(flags)) {
+ addZeroCopyBuff(ioBuf.get());
+ }
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
// Declare a DestructorGuard to ensure that the AsyncSocket cannot be
// destroyed until close() returns.
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
// Since there are write requests pending, we have to set the
// SHUT_WRITE_PENDING flag, and wait to perform the real close until the
<< ", state=" << state_ << ", shutdownFlags="
<< std::hex << (int) shutdownFlags_;
DestructorGuard dg(this);
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
switch (state_) {
case StateEnum::ESTABLISHED:
return;
}
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
// There are pending writes. Set SHUT_WRITE_PENDING so that the actual
// shutdown will be performed once all writes complete.
}
DestructorGuard dg(this);
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
switch (static_cast<StateEnum>(state_)) {
case StateEnum::ESTABLISHED:
<< ", state=" << state_ << ", events="
<< std::hex << eventFlags_ << ")";
assert(eventBase_ == nullptr);
- assert(eventBase->isInEventBaseThread());
+ eventBase->dcheckIsInEventBaseThread();
eventBase_ = eventBase;
ioHandler_.attachEventBase(eventBase);
<< ", old evb=" << eventBase_ << ", state=" << state_
<< ", events=" << std::hex << eventFlags_ << ")";
assert(eventBase_ != nullptr);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
eventBase_ = nullptr;
ioHandler_.detachEventBase();
bool AsyncSocket::isDetachable() const {
DCHECK(eventBase_ != nullptr);
- DCHECK(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
}
-void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+// Best-effort caching of both endpoint addresses. ENOTCONN is expected for
+// sockets that are not yet connected and is silently ignored; any other
+// failure is only logged since callers treat this as advisory.
+void AsyncSocket::cacheAddresses() {
+ if (fd_ >= 0) {
+ try {
+ cacheLocalAddress();
+ cachePeerAddress();
+ } catch (const std::system_error& e) {
+ if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
+ VLOG(1) << "Error caching addresses: " << e.code().value() << ", "
+ << e.code().message();
+ }
+ }
+ }
+}
+
+// Populate localAddr_ from the fd on first use (no-op once cached).
+// const because localAddr_ is a mutable cache field.
+void AsyncSocket::cacheLocalAddress() const {
if (!localAddr_.isInitialized()) {
localAddr_.setFromLocalAddress(fd_);
}
- *address = localAddr_;
}
-void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+// Populate addr_ (the peer address) from the fd on first use.
+void AsyncSocket::cachePeerAddress() const {
if (!addr_.isInitialized()) {
addr_.setFromPeerAddress(fd_);
}
+}
+
+// Public accessor: ensures the cache is populated, then copies it out.
+void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+ cacheLocalAddress();
+ *address = localAddr_;
+}
+
+// Public accessor: ensures the cache is populated, then copies it out.
+void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+ cachePeerAddress();
*address = addr_;
}
<< ", events=" << std::hex << events << ", state=" << state_;
DestructorGuard dg(this);
assert(events & EventHandler::READ_WRITE);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
EventBase* originalEventBase = eventBase_;
// supporting per-socket error queues.
VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
<< ", state=" << state_;
- if (errMessageCallback_ == nullptr) {
+ if (errMessageCallback_ == nullptr &&
+ (!zeroCopyEnabled_ || idZeroCopyBufPtrMap_.empty())) {
VLOG(7) << "AsyncSocket::handleErrMessages(): "
<< "no callback installed - exiting.";
return;
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg != nullptr && cmsg->cmsg_len != 0;
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
- errMessageCallback_->errMessage(*cmsg);
+ if (isZeroCopyMsg(*cmsg)) {
+ processZeroCopyMsg(*cmsg);
+ } else {
+ if (errMessageCallback_) {
+ errMessageCallback_->errMessage(*cmsg);
+ }
+ }
}
}
#endif //MSG_ERRQUEUE
VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
<< "state=" << state_ << ", events=" << std::hex << eventFlags_;
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (state_ == StateEnum::CONNECTING) {
// connect() timed out
} else {
msg.msg_control = nullptr;
}
- int msg_flags = sendMsgParamCallback_->getFlags(flags);
+ int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
auto totalWritten = writeResult.writeReturn;
* and call all currently installed callbacks. After an error, the
* AsyncSocket is completely unregistered.
*
- * @return Returns true on succcess, or false on error.
+ * @return Returns true on success, or false on error.
*/
bool AsyncSocket::updateEventRegistration() {
VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
<< ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
<< ", events=" << std::hex << eventFlags_;
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (eventFlags_ == EventHandler::NONE) {
ioHandler_.unregisterHandler();
return true;
bufferCallback_ = cb;
}
-} // folly
+} // namespace folly