/*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <folly/io/async/AsyncSocket.h>
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/EventHandler.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/Portability.h>
#include <folly/SocketAddress.h>
+#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
+#include <folly/io/IOBufQueue.h>
+#include <folly/portability/Fcntl.h>
+#include <folly/portability/Sockets.h>
#include <folly/portability/SysUio.h>
+#include <folly/portability/Unistd.h>
-#include <poll.h>
+#include <boost/preprocessor/control/if.hpp>
#include <errno.h>
#include <limits.h>
-#include <unistd.h>
-#include <thread>
-#include <fcntl.h>
#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/in.h>
-#include <netinet/tcp.h>
-#include <boost/preprocessor/control/if.hpp>
+#include <thread>
using std::string;
using std::unique_ptr;
+namespace fsp = folly::portability::sockets;
+
namespace folly {
+static constexpr bool msgErrQueueSupported =
+#ifdef MSG_ERRQUEUE
+ true;
+#else
+ false;
+#endif // MSG_ERRQUEUE
+
// static members initializers
const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
free(this);
}
- bool performWrite() override {
+ WriteResult performWrite() override {
WriteFlags writeFlags = flags_;
if (getNext() != nullptr) {
- writeFlags = writeFlags | WriteFlags::CORK;
+ writeFlags |= WriteFlags::CORK;
}
- bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags,
- &opsWritten_, &partialBytes_);
- return bytesWritten_ >= 0;
+
+ socket_->adjustZeroCopyFlags(getOps(), getOpCount(), writeFlags);
+
+ auto writeResult = socket_->performWrite(
+ getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
+ bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+ if (bytesWritten_) {
+ if (socket_->isZeroCopyRequest(writeFlags)) {
+ if (isComplete()) {
+ socket_->addZeroCopyBuff(std::move(ioBuf_));
+ } else {
+ socket_->addZeroCopyBuff(ioBuf_.get());
+ }
+ } else {
+ // this happens if at least one of the prev requests were sent
+ // with zero copy but not the last one
+ if (isComplete() && socket_->getZeroCopy() &&
+ socket_->containsZeroCopyBuff(ioBuf_.get())) {
+ socket_->setZeroCopyBuff(std::move(ioBuf_));
+ }
+ }
+ }
+ return writeResult;
}
bool isComplete() override {
opIndex_ += opsWritten_;
assert(opIndex_ < opCount_);
- // If we've finished writing any IOBufs, release them
- if (ioBuf_) {
- for (uint32_t i = opsWritten_; i != 0; --i) {
- assert(ioBuf_);
- ioBuf_ = ioBuf_->pop();
+ if (!socket_->isZeroCopyRequest(flags_)) {
+ // If we've finished writing any IOBufs, release them
+ if (ioBuf_) {
+ for (uint32_t i = opsWritten_; i != 0; --i) {
+ assert(ioBuf_);
+ ioBuf_ = ioBuf_->pop();
+ }
}
}
currentOp->iov_len -= partialBytes_;
// Increment the totalBytesWritten_ count by bytesWritten_;
- totalBytesWritten_ += bytesWritten_;
+ assert(bytesWritten_ >= 0);
+ totalBytesWritten_ += uint32_t(bytesWritten_);
}
private:
struct iovec writeOps_[]; ///< write operation(s) list
};
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
+ folly::WriteFlags flags,
+ bool zeroCopyEnabled) noexcept {
+ int msg_flags = MSG_DONTWAIT;
+
+#ifdef MSG_NOSIGNAL // Linux-only
+ msg_flags |= MSG_NOSIGNAL;
+#ifdef MSG_MORE
+ if (isSet(flags, WriteFlags::CORK)) {
+ // MSG_MORE tells the kernel we have more data to send, so wait for us to
+ // give it the rest of the data rather than immediately sending a partial
+ // frame, even when TCP_NODELAY is enabled.
+ msg_flags |= MSG_MORE;
+ }
+#endif // MSG_MORE
+#endif // MSG_NOSIGNAL
+ if (isSet(flags, WriteFlags::EOR)) {
+ // marks that this is the last byte of a record (response)
+ msg_flags |= MSG_EOR;
+ }
+
+ if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
+ msg_flags |= MSG_ZEROCOPY;
+ }
+
+ return msg_flags;
+}
+
+namespace {
+static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
+}
+
AsyncSocket::AsyncSocket()
- : eventBase_(nullptr)
- , writeTimeout_(this, nullptr)
- , ioHandler_(this, nullptr)
- , immediateReadHandler_(this) {
+ : eventBase_(nullptr),
+ writeTimeout_(this, nullptr),
+ ioHandler_(this, nullptr),
+ immediateReadHandler_(this) {
VLOG(5) << "new AsyncSocket()";
init();
}
AsyncSocket::AsyncSocket(EventBase* evb)
- : eventBase_(evb)
- , writeTimeout_(this, evb)
- , ioHandler_(this, evb)
- , immediateReadHandler_(this) {
+ : eventBase_(evb),
+ writeTimeout_(this, evb),
+ ioHandler_(this, evb),
+ immediateReadHandler_(this) {
VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
init();
}
}
AsyncSocket::AsyncSocket(EventBase* evb, int fd)
- : eventBase_(evb)
- , writeTimeout_(this, evb)
- , ioHandler_(this, evb, fd)
- , immediateReadHandler_(this) {
+ : eventBase_(evb),
+ writeTimeout_(this, evb),
+ ioHandler_(this, evb, fd),
+ immediateReadHandler_(this) {
VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
<< fd << ")";
init();
state_ = StateEnum::ESTABLISHED;
}
+AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
+ : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
+ preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
+}
+
// init() method, since constructor forwarding isn't supported in most
// compilers yet.
void AsyncSocket::init() {
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
shutdownFlags_ = 0;
state_ = StateEnum::UNINIT;
eventFlags_ = EventHandler::NONE;
sendTimeout_ = 0;
maxReadsPerEvent_ = 16;
connectCallback_ = nullptr;
+ errMessageCallback_ = nullptr;
readCallback_ = nullptr;
writeReqHead_ = nullptr;
writeReqTail_ = nullptr;
shutdownSocketSet_ = nullptr;
appBytesWritten_ = 0;
appBytesReceived_ = 0;
+ sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
}
AsyncSocket::~AsyncSocket() {
void AsyncSocket::setCloseOnExec() {
int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
if (rv != 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to set close-on-exec flag"),
- errno);
+ auto errnoCopy = errno;
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to set close-on-exec flag"),
+ errnoCopy);
}
}
const OptionMap &options,
const folly::SocketAddress& bindAddr) noexcept {
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
addr_ = address;
return invalidState(callback);
}
+ connectTimeout_ = std::chrono::milliseconds(timeout);
connectStartTime_ = std::chrono::steady_clock::now();
// Make connect end time at least >= connectStartTime.
connectEndTime_ = connectStartTime_;
// constant (PF_xxx) rather than an address family (AF_xxx), but the
// distinction is mainly just historical. In pretty much all
// implementations the PF_foo and AF_foo constants are identical.
- fd_ = socket(address.getFamily(), SOCK_STREAM, 0);
+ fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
if (fd_ < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to create socket"), errno);
+ auto errnoCopy = errno;
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to create socket"),
+ errnoCopy);
}
if (shutdownSocketSet_) {
shutdownSocketSet_->add(fd_);
// Put the socket in non-blocking mode
int flags = fcntl(fd_, F_GETFL, 0);
if (flags == -1) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to get socket flags"), errno);
+ auto errnoCopy = errno;
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to get socket flags"),
+ errnoCopy);
}
int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
if (rv == -1) {
+ auto errnoCopy = errno;
throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
withAddr("failed to put socket in non-blocking mode"),
- errno);
+ errnoCopy);
}
#if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
// iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
if (rv == -1) {
+ auto errnoCopy = errno;
throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
"failed to enable F_SETNOSIGPIPE on socket",
- errno);
+ errnoCopy);
}
#endif
// By default, turn on TCP_NODELAY
// If setNoDelay() fails, we continue anyway; this isn't a fatal error.
// setNoDelay() will log an error message if it fails.
+ // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
+ // is not created
if (address.getFamily() != AF_UNIX) {
(void)setNoDelay(true);
+ setZeroCopy(zeroCopyVal_);
}
VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
// bind the socket
if (bindAddr != anyAddress()) {
int one = 1;
- if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
+ if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
+ auto errnoCopy = errno;
doClose();
throw AsyncSocketException(
- AsyncSocketException::NOT_OPEN,
- "failed to setsockopt prior to bind on " + bindAddr.describe(),
- errno);
+ AsyncSocketException::NOT_OPEN,
+ "failed to setsockopt prior to bind on " + bindAddr.describe(),
+ errnoCopy);
}
bindAddr.getAddress(&addrStorage);
- if (::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
+ if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
+ auto errnoCopy = errno;
doClose();
- throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
- "failed to bind to async socket: " +
- bindAddr.describe(),
- errno);
+ throw AsyncSocketException(
+ AsyncSocketException::NOT_OPEN,
+ "failed to bind to async socket: " + bindAddr.describe(),
+ errnoCopy);
}
}
// Apply the additional options if any.
for (const auto& opt: options) {
- int rv = opt.first.apply(fd_, opt.second);
+ rv = opt.first.apply(fd_, opt.second);
if (rv != 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to set socket option"),
- errno);
+ auto errnoCopy = errno;
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to set socket option"),
+ errnoCopy);
}
}
// Perform the connect()
address.getAddress(&addrStorage);
- rv = ::connect(fd_, saddr, address.getActualSize());
- if (rv < 0) {
- if (errno == EINPROGRESS) {
- // Connection in progress.
- if (timeout > 0) {
- // Start a timer in case the connection takes too long.
- if (!writeTimeout_.scheduleTimeout(timeout)) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to schedule AsyncSocket connect timeout"));
- }
- }
-
- // Register for write events, so we'll
- // be notified when the connection finishes/fails.
- // Note that we don't register for a persistent event here.
- assert(eventFlags_ == EventHandler::NONE);
- eventFlags_ = EventHandler::WRITE;
- if (!ioHandler_.registerHandler(eventFlags_)) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- withAddr("failed to register AsyncSocket connect handler"));
- }
+ if (tfoEnabled_) {
+ state_ = StateEnum::FAST_OPEN;
+ tfoAttempted_ = true;
+ } else {
+ if (socketConnect(saddr, addr_.getActualSize()) < 0) {
return;
- } else {
- throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
- "connect failed (immediately)", errno);
}
}
// The read callback may not have been set yet, and no writes may be pending
// yet, so we don't have to register for any events at the moment.
VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
+ assert(errMessageCallback_ == nullptr);
assert(readCallback_ == nullptr);
assert(writeReqHead_ == nullptr);
- state_ = StateEnum::ESTABLISHED;
+ if (state_ != StateEnum::FAST_OPEN) {
+ state_ = StateEnum::ESTABLISHED;
+ }
invokeConnectSuccess();
}
+int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
+#if __linux__
+ if (noTransparentTls_) {
+ // Ignore return value, errors are ok
+ setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
+ }
+ if (noTSocks_) {
+ VLOG(4) << "Disabling TSOCKS for fd " << fd_;
+ // Ignore return value, errors are ok
+ setsockopt(fd_, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
+ }
+#endif
+ int rv = fsp::connect(fd_, saddr, len);
+ if (rv < 0) {
+ auto errnoCopy = errno;
+ if (errnoCopy == EINPROGRESS) {
+ scheduleConnectTimeout();
+ registerForConnectEvents();
+ } else {
+ throw AsyncSocketException(
+ AsyncSocketException::NOT_OPEN,
+ "connect failed (immediately)",
+ errnoCopy);
+ }
+ }
+ return rv;
+}
+
+void AsyncSocket::scheduleConnectTimeout() {
+ // Connection in progress.
+ auto timeout = connectTimeout_.count();
+ if (timeout > 0) {
+ // Start a timer in case the connection takes too long.
+ if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to schedule AsyncSocket connect timeout"));
+ }
+ }
+}
+
+void AsyncSocket::registerForConnectEvents() {
+ // Register for write events, so we'll
+ // be notified when the connection finishes/fails.
+ // Note that we don't register for a persistent event here.
+ assert(eventFlags_ == EventHandler::NONE);
+ eventFlags_ = EventHandler::WRITE;
+ if (!ioHandler_.registerHandler(eventFlags_)) {
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to register AsyncSocket connect handler"));
+ }
+}
+
void AsyncSocket::connect(ConnectCallback* callback,
const string& ip, uint16_t port,
int timeout,
void AsyncSocket::cancelConnect() {
connectCallback_ = nullptr;
- if (state_ == StateEnum::CONNECTING) {
+ if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
closeNow();
}
}
void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
sendTimeout_ = milliseconds;
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
// If we are currently pending on write requests, immediately update
// writeTimeout_ with the new value.
if ((eventFlags_ & EventHandler::WRITE) &&
- (state_ != StateEnum::CONNECTING)) {
+ (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
assert(state_ == StateEnum::ESTABLISHED);
assert((shutdownFlags_ & SHUT_WRITE) == 0);
if (sendTimeout_ > 0) {
}
}
+void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
+ VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this
+ << ", fd=" << fd_ << ", callback=" << callback
+ << ", state=" << state_;
+
+ // Short circuit if callback is the same as the existing errMessageCallback_.
+ if (callback == errMessageCallback_) {
+ return;
+ }
+
+ if (!msgErrQueueSupported) {
+ // Per-socket error message queue is not supported on this platform.
+ return invalidState(callback);
+ }
+
+ DestructorGuard dg(this);
+ eventBase_->dcheckIsInEventBaseThread();
+
+ if (callback == nullptr) {
+ // We should be able to reset the callback regardless of the
+ // socket state. It's important to have a reliable callback
+ // cancellation mechanism.
+ errMessageCallback_ = callback;
+ return;
+ }
+
+ switch ((StateEnum)state_) {
+ case StateEnum::CONNECTING:
+ case StateEnum::FAST_OPEN:
+ case StateEnum::ESTABLISHED: {
+ errMessageCallback_ = callback;
+ return;
+ }
+ case StateEnum::CLOSED:
+ case StateEnum::ERROR:
+ // We should never reach here. SHUT_READ should always be set
+ // if we are in STATE_CLOSED or STATE_ERROR.
+ assert(false);
+ return invalidState(callback);
+ case StateEnum::UNINIT:
+ // We do not allow setReadCallback() to be called before we start
+ // connecting.
+ return invalidState(callback);
+ }
+
+ // We don't put a default case in the switch statement, so that the compiler
+ // will warn us to update the switch statement if a new state is added.
+ return invalidState(callback);
+}
+
+AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
+ return errMessageCallback_;
+}
+
+void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
+ sendMsgParamCallback_ = callback;
+}
+
+AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
+ return sendMsgParamCallback_;
+}
+
void AsyncSocket::setReadCB(ReadCallback *callback) {
VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", state=" << state_;
}
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
switch ((StateEnum)state_) {
case StateEnum::CONNECTING:
+ case StateEnum::FAST_OPEN:
// For convenience, we allow the read callback to be set while we are
// still connecting. We just store the callback for now. Once the
// connection completes we'll register for read events.
return readCallback_;
}
+bool AsyncSocket::setZeroCopy(bool enable) {
+ if (msgErrQueueSupported) {
+ zeroCopyVal_ = enable;
+
+ if (fd_ < 0) {
+ return false;
+ }
+
+ int val = enable ? 1 : 0;
+ int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
+
+ // if enable == false, set zeroCopyEnabled_ = false regardless
+ // if SO_ZEROCOPY is set or not
+ if (!enable) {
+ zeroCopyEnabled_ = enable;
+ return true;
+ }
+
+ /* if the setsockopt failed, try to see if the socket inherited the flag
+ * since we cannot set SO_ZEROCOPY on a socket s = accept
+ */
+ if (ret) {
+ val = 0;
+ socklen_t optlen = sizeof(val);
+ ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
+
+ if (!ret) {
+ enable = val ? true : false;
+ }
+ }
+
+ if (!ret) {
+ zeroCopyEnabled_ = enable;
+
+ return true;
+ }
+ }
+
+ return false;
+}
+
+void AsyncSocket::setZeroCopyWriteChainThreshold(size_t threshold) {
+ zeroCopyWriteChainThreshold_ = threshold;
+}
+
+bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
+ return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
+}
+
+void AsyncSocket::adjustZeroCopyFlags(
+ folly::IOBuf* buf,
+ folly::WriteFlags& flags) {
+ if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_ && buf) {
+ if (buf->computeChainDataLength() >= zeroCopyWriteChainThreshold_) {
+ flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+ } else {
+ flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+ }
+ }
+}
+
+void AsyncSocket::adjustZeroCopyFlags(
+ const iovec* vec,
+ uint32_t count,
+ folly::WriteFlags& flags) {
+ if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_) {
+ count = std::min<uint32_t>(count, kIovMax);
+ size_t sum = 0;
+ for (uint32_t i = 0; i < count; ++i) {
+ const iovec* v = vec + i;
+ sum += v->iov_len;
+ }
+
+ if (sum >= zeroCopyWriteChainThreshold_) {
+ flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+ } else {
+ flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+ }
+ }
+}
+
+void AsyncSocket::addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+ uint32_t id = getNextZeroCopyBuffId();
+ folly::IOBuf* ptr = buf.get();
+
+ idZeroCopyBufPtrMap_[id] = ptr;
+ auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+ p.first++;
+ CHECK(p.second.get() == nullptr);
+ p.second = std::move(buf);
+}
+
+void AsyncSocket::addZeroCopyBuff(folly::IOBuf* ptr) {
+ uint32_t id = getNextZeroCopyBuffId();
+ idZeroCopyBufPtrMap_[id] = ptr;
+
+ idZeroCopyBufPtrToBufMap_[ptr].first++;
+}
+
+void AsyncSocket::releaseZeroCopyBuff(uint32_t id) {
+ auto iter = idZeroCopyBufPtrMap_.find(id);
+ CHECK(iter != idZeroCopyBufPtrMap_.end());
+ auto ptr = iter->second;
+ auto iter1 = idZeroCopyBufPtrToBufMap_.find(ptr);
+ CHECK(iter1 != idZeroCopyBufPtrToBufMap_.end());
+ if (0 == --iter1->second.first) {
+ idZeroCopyBufPtrToBufMap_.erase(iter1);
+ }
+}
+
+void AsyncSocket::setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+ folly::IOBuf* ptr = buf.get();
+ auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+ CHECK(p.second.get() == nullptr);
+
+ p.second = std::move(buf);
+}
+
+bool AsyncSocket::containsZeroCopyBuff(folly::IOBuf* ptr) {
+ return (
+ idZeroCopyBufPtrToBufMap_.find(ptr) != idZeroCopyBufPtrToBufMap_.end());
+}
+
+bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
+#ifdef MSG_ERRQUEUE
+ if (zeroCopyEnabled_ &&
+ ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+ (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) {
+ const struct sock_extended_err* serr =
+ reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+ return (
+ (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
+ }
+#endif
+ return false;
+}
+
+void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
+#ifdef MSG_ERRQUEUE
+ const struct sock_extended_err* serr =
+ reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+ uint32_t hi = serr->ee_data;
+ uint32_t lo = serr->ee_info;
+
+ for (uint32_t i = lo; i <= hi; i++) {
+ releaseZeroCopyBuff(i);
+ }
+#endif
+}
+
void AsyncSocket::write(WriteCallback* callback,
const void* buf, size_t bytes, WriteFlags flags) {
iovec op;
void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
WriteFlags flags) {
+ adjustZeroCopyFlags(buf.get(), flags);
+
constexpr size_t kSmallSizeMax = 64;
size_t count = buf->countChainElements();
if (count <= kSmallSizeMax) {
+ // suppress "warning: variable length array 'vec' is used [-Wvla]"
+ FOLLY_PUSH_WARNING
+ FOLLY_GCC_DISABLE_WARNING("-Wvla")
iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
+ FOLLY_POP_WARNING
+
writeChainImpl(callback, vec, count, std::move(buf), flags);
} else {
iovec* vec = new iovec[count];
<< ", state=" << state_;
DestructorGuard dg(this);
unique_ptr<IOBuf>ioBuf(std::move(buf));
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
// No new writes may be performed after the write side of the socket has
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
- int bytesWritten = 0;
+ ssize_t bytesWritten = 0;
bool mustRegister = false;
- if (state_ == StateEnum::ESTABLISHED && !connecting()) {
+ if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
+ !connecting()) {
if (writeReqHead_ == nullptr) {
// If we are established and there are no other writes pending,
// we can attempt to perform the write immediately.
assert(writeReqTail_ == nullptr);
assert((eventFlags_ & EventHandler::WRITE) == 0);
- bytesWritten = performWrite(vec, count, flags,
- &countWritten, &partialWritten);
+ auto writeResult = performWrite(
+ vec, uint32_t(count), flags, &countWritten, &partialWritten);
+ bytesWritten = writeResult.writeReturn;
if (bytesWritten < 0) {
- AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
- withAddr("writev failed"), errno);
+ auto errnoCopy = errno;
+ if (writeResult.exception) {
+ return failWrite(__func__, callback, 0, *writeResult.exception);
+ }
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("writev failed"),
+ errnoCopy);
return failWrite(__func__, callback, 0, ex);
} else if (countWritten == count) {
+ // done, add the whole buffer
+ if (isZeroCopyRequest(flags)) {
+ addZeroCopyBuff(std::move(ioBuf));
+ }
// We successfully wrote everything.
// Invoke the callback and return.
if (callback) {
}
return;
} else { // continue writing the next writeReq
+ // add just the ptr
+ if (isZeroCopyRequest(flags)) {
+ addZeroCopyBuff(ioBuf.get());
+ }
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
}
- mustRegister = true;
+ if (!connecting()) {
+ // Writes might put the socket back into connecting state
+ // if TFO is enabled, and using TFO fails.
+ // This means that write timeouts would not be active, however
+ // connect timeouts would affect this stage.
+ mustRegister = true;
+ }
}
} else if (!connecting()) {
// Invalid state for writing
// Create a new WriteRequest to add to the queue
WriteRequest* req;
try {
- req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
- count - countWritten, partialWritten,
- bytesWritten, std::move(ioBuf), flags);
+ req = BytesWriteRequest::newRequest(
+ this,
+ callback,
+ vec + countWritten,
+ uint32_t(count - countWritten),
+ partialWritten,
+ uint32_t(bytesWritten),
+ std::move(ioBuf),
+ flags);
} catch (const std::exception& ex) {
// we mainly expect to catch std::bad_alloc here
AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
withAddr(string("failed to append new WriteRequest: ") + ex.what()));
- return failWrite(__func__, callback, bytesWritten, tex);
+ return failWrite(__func__, callback, size_t(bytesWritten), tex);
}
req->consume();
if (writeReqTail_ == nullptr) {
// Declare a DestructorGuard to ensure that the AsyncSocket cannot be
// destroyed until close() returns.
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
// Since there are write requests pending, we have to set the
// SHUT_WRITE_PENDING flag, and wait to perform the real close until the
<< ", state=" << state_ << ", shutdownFlags="
<< std::hex << (int) shutdownFlags_;
DestructorGuard dg(this);
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
switch (state_) {
case StateEnum::ESTABLISHED:
case StateEnum::CONNECTING:
- {
+ case StateEnum::FAST_OPEN: {
shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
state_ = StateEnum::CLOSED;
return;
}
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
// There are pending writes. Set SHUT_WRITE_PENDING so that the actual
// shutdown will be performed once all writes complete.
}
DestructorGuard dg(this);
- assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+ if (eventBase_) {
+ eventBase_->dcheckIsInEventBaseThread();
+ }
switch (static_cast<StateEnum>(state_)) {
case StateEnum::ESTABLISHED:
}
// Shutdown writes on the file descriptor
- ::shutdown(fd_, SHUT_WR);
+ shutdown(fd_, SHUT_WR);
// Immediately fail all write requests
failAllWrites(socketShutdownForWritesEx);
// immediately shut down the write side of the socket.
shutdownFlags_ |= SHUT_WRITE_PENDING;
return;
+ case StateEnum::FAST_OPEN:
+ // In fast open state we haven't call connected yet, and if we shutdown
+ // the writes, we will never try to call connect, so shut everything down
+ shutdownFlags_ |= SHUT_WRITE;
+ // Immediately fail all write requests
+ failAllWrites(socketShutdownForWritesEx);
+ return;
case StateEnum::CLOSED:
case StateEnum::ERROR:
// We should never get here. SHUT_WRITE should always be set
return rc == 1;
}
+bool AsyncSocket::writable() const {
+ if (fd_ == -1) {
+ return false;
+ }
+ struct pollfd fds[1];
+ fds[0].fd = fd_;
+ fds[0].events = POLLOUT;
+ fds[0].revents = 0;
+ int rc = poll(fds, 1, 0);
+ return rc == 1;
+}
+
bool AsyncSocket::isPending() const {
return ioHandler_.isPending();
}
}
bool AsyncSocket::good() const {
- return ((state_ == StateEnum::CONNECTING ||
- state_ == StateEnum::ESTABLISHED) &&
- (shutdownFlags_ == 0) && (eventBase_ != nullptr));
+ return (
+ (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
+ state_ == StateEnum::ESTABLISHED) &&
+ (shutdownFlags_ == 0) && (eventBase_ != nullptr));
}
bool AsyncSocket::error() const {
<< ", state=" << state_ << ", events="
<< std::hex << eventFlags_ << ")";
assert(eventBase_ == nullptr);
- assert(eventBase->isInEventBaseThread());
+ eventBase->dcheckIsInEventBaseThread();
eventBase_ = eventBase;
ioHandler_.attachEventBase(eventBase);
writeTimeout_.attachEventBase(eventBase);
+ if (evbChangeCb_) {
+ evbChangeCb_->evbAttached(this);
+ }
}
void AsyncSocket::detachEventBase() {
<< ", old evb=" << eventBase_ << ", state=" << state_
<< ", events=" << std::hex << eventFlags_ << ")";
assert(eventBase_ != nullptr);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
eventBase_ = nullptr;
ioHandler_.detachEventBase();
writeTimeout_.detachEventBase();
+ if (evbChangeCb_) {
+ evbChangeCb_->evbDetached(this);
+ }
}
bool AsyncSocket::isDetachable() const {
DCHECK(eventBase_ != nullptr);
- DCHECK(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
}
-void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+void AsyncSocket::cacheAddresses() {
+ if (fd_ >= 0) {
+ try {
+ cacheLocalAddress();
+ cachePeerAddress();
+ } catch (const std::system_error& e) {
+ if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
+ VLOG(1) << "Error caching addresses: " << e.code().value() << ", "
+ << e.code().message();
+ }
+ }
+ }
+}
+
+void AsyncSocket::cacheLocalAddress() const {
if (!localAddr_.isInitialized()) {
localAddr_.setFromLocalAddress(fd_);
}
- *address = localAddr_;
}
-void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+void AsyncSocket::cachePeerAddress() const {
if (!addr_.isInitialized()) {
addr_.setFromPeerAddress(fd_);
}
+}
+
+void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+ cacheLocalAddress();
+ *address = localAddr_;
+}
+
+void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+ cachePeerAddress();
*address = addr_;
}
+bool AsyncSocket::getTFOSucceded() const {
+ return detail::tfo_succeeded(fd_);
+}
+
int AsyncSocket::setNoDelay(bool noDelay) {
if (fd_ < 0) {
VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
}
- if (setsockopt(fd_, IPPROTO_TCP, TCP_CONGESTION, cname.c_str(),
- cname.length() + 1) != 0) {
+ if (setsockopt(
+ fd_,
+ IPPROTO_TCP,
+ TCP_CONGESTION,
+ cname.c_str(),
+ socklen_t(cname.length() + 1)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
<< this << "(fd=" << fd_ << ", state=" << state_ << "): "
}
int AsyncSocket::setQuickAck(bool quickack) {
+ (void)quickack;
if (fd_ < 0) {
VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
<< this << "(state=" << state_ << ")";
}
void AsyncSocket::ioReady(uint16_t events) noexcept {
- VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
+ VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_
<< ", events=" << std::hex << events << ", state=" << state_;
DestructorGuard dg(this);
assert(events & EventHandler::READ_WRITE);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
+
+ uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
+ EventBase* originalEventBase = eventBase_;
+ // If we got there it means that either EventHandler::READ or
+ // EventHandler::WRITE is set. Any of these flags can
+ // indicate that there are messages available in the socket
+ // error message queue.
+ handleErrMessages();
+
+ // Return now if handleErrMessages() detached us from our EventBase
+ if (eventBase_ != originalEventBase) {
+ return;
+ }
- uint16_t relevantEvents = events & EventHandler::READ_WRITE;
if (relevantEvents == EventHandler::READ) {
handleRead();
} else if (relevantEvents == EventHandler::WRITE) {
handleWrite();
} else if (relevantEvents == EventHandler::READ_WRITE) {
- EventBase* originalEventBase = eventBase_;
// If both read and write events are ready, process writes first.
handleWrite();
}
}
-ssize_t AsyncSocket::performRead(void** buf,
- size_t* buflen,
- size_t* /* offset */) {
- VLOG(5) << "AsyncSocket::performRead() this=" << this
- << ", buf=" << *buf << ", buflen=" << *buflen;
+AsyncSocket::ReadResult
+AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
+ VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
+ << ", buflen=" << *buflen;
- int recvFlags = 0;
- if (peek_) {
- recvFlags |= MSG_PEEK;
+ if (preReceivedData_ && !preReceivedData_->empty()) {
+ VLOG(5) << "AsyncSocket::performRead() this=" << this
+ << ", reading pre-received data";
+
+ io::Cursor cursor(preReceivedData_.get());
+ auto len = cursor.pullAtMost(*buf, *buflen);
+
+ IOBufQueue queue;
+ queue.append(std::move(preReceivedData_));
+ queue.trimStart(len);
+ preReceivedData_ = queue.move();
+
+ appBytesReceived_ += len;
+ return ReadResult(len);
}
- ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
+ ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT);
if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now.
- return READ_BLOCKING;
+ return ReadResult(READ_BLOCKING);
} else {
- return READ_ERROR;
+ return ReadResult(READ_ERROR);
}
} else {
appBytesReceived_ += bytes;
- return bytes;
+ return ReadResult(bytes);
}
}
-void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
// no matter what, buffer should be preapared for non-ssl socket
CHECK(readCallback_);
readCallback_->getReadBuffer(buf, buflen);
}
+void AsyncSocket::handleErrMessages() noexcept {
+ // This method has non-empty implementation only for platforms
+ // supporting per-socket error queues.
+ VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
+ << ", state=" << state_;
+ if (errMessageCallback_ == nullptr &&
+ (!zeroCopyEnabled_ || idZeroCopyBufPtrMap_.empty())) {
+ VLOG(7) << "AsyncSocket::handleErrMessages(): "
+ << "no callback installed - exiting.";
+ return;
+ }
+
+#ifdef MSG_ERRQUEUE
+ uint8_t ctrl[1024];
+ unsigned char data;
+ struct msghdr msg;
+ iovec entry;
+
+ entry.iov_base = &data;
+ entry.iov_len = sizeof(data);
+ msg.msg_iov = &entry;
+ msg.msg_iovlen = 1;
+ msg.msg_name = nullptr;
+ msg.msg_namelen = 0;
+ msg.msg_control = ctrl;
+ msg.msg_controllen = sizeof(ctrl);
+ msg.msg_flags = 0;
+
+ int ret;
+ while (true) {
+ ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
+ VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
+
+ if (ret < 0) {
+ if (errno != EAGAIN) {
+ auto errnoCopy = errno;
+ LOG(ERROR) << "::recvmsg exited with code " << ret
+ << ", errno: " << errnoCopy;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("recvmsg() failed"),
+ errnoCopy);
+ failErrMessageRead(__func__, ex);
+ }
+ return;
+ }
+
+ for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg != nullptr && cmsg->cmsg_len != 0;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (isZeroCopyMsg(*cmsg)) {
+ processZeroCopyMsg(*cmsg);
+ } else {
+ if (errMessageCallback_) {
+ errMessageCallback_->errMessage(*cmsg);
+ }
+ }
+ }
+ }
+#endif //MSG_ERRQUEUE
+}
+
void AsyncSocket::handleRead() noexcept {
VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
<< ", state=" << state_;
}
// Perform the read
- ssize_t bytesRead = performRead(&buf, &buflen, &offset);
+ auto readResult = performRead(&buf, &buflen, &offset);
+ auto bytesRead = readResult.readReturn;
VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
<< bytesRead << " bytes";
if (bytesRead > 0) {
if (!isBufferMovable_) {
- readCallback_->readDataAvailable(bytesRead);
+ readCallback_->readDataAvailable(size_t(bytesRead));
} else {
CHECK(kOpenSslModeMoveBufferOwnership);
VLOG(5) << "this=" << this << ", AsyncSocket::handleRead() got "
return;
} else if (bytesRead == READ_ERROR) {
readErr_ = READ_ERROR;
- AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
- withAddr("recv() failed"), errno);
+ if (readResult.exception) {
+ return failRead(__func__, *readResult.exception);
+ }
+ auto errnoCopy = errno;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("recv() failed"),
+ errnoCopy);
return failRead(__func__, ex);
} else {
assert(bytesRead == READ_EOF);
void AsyncSocket::handleWrite() noexcept {
VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
<< ", state=" << state_;
+ DestructorGuard dg(this);
+
if (state_ == StateEnum::CONNECTING) {
handleConnect();
return;
// (See the comment in handleRead() explaining how this can happen.)
EventBase* originalEventBase = eventBase_;
while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
- if (!writeReqHead_->performWrite()) {
- AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
- withAddr("writev() failed"), errno);
+ auto writeResult = writeReqHead_->performWrite();
+ if (writeResult.writeReturn < 0) {
+ if (writeResult.exception) {
+ return failWrite(__func__, *writeResult.exception);
+ }
+ auto errnoCopy = errno;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("writev() failed"),
+ errnoCopy);
return failWrite(__func__, ex);
} else if (writeReqHead_->isComplete()) {
// We finished this request
}
} else {
// Reads are still enabled, so we are only doing a half-shutdown
- ::shutdown(fd_, SHUT_WR);
+ shutdown(fd_, SHUT_WR);
}
}
}
// be a pessimism. In most cases it probably wouldn't be readable, and we
// would just waste an extra system call. Even if it is readable, waiting to
// find out from libevent on the next event loop doesn't seem that bad.
+ //
+ // The exception to this is if we have pre-received data. In that case there
+ // is definitely data available immediately.
+ if (preReceivedData_ && !preReceivedData_->empty()) {
+ handleRead();
+ }
}
void AsyncSocket::handleInitialReadWrite() noexcept {
// one here just to make sure, in case one of our calling code paths ever
// changes.
DestructorGuard dg(this);
-
// If we have a readCallback_, make sure we enable read events. We
// may already be registered for reads if connectSuccess() set
// the read calback.
socklen_t len = sizeof(error);
int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
if (rv != 0) {
- AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
- withAddr("error calling getsockopt() after connect"),
- errno);
+ auto errnoCopy = errno;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("error calling getsockopt() after connect"),
+ errnoCopy);
VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
<< fd_ << " host=" << addr_.describe()
<< ") exception:" << ex.what();
// are still connecting we just abort the connect rather than waiting for
// it to complete.
assert((shutdownFlags_ & SHUT_READ) == 0);
- ::shutdown(fd_, SHUT_WR);
+ shutdown(fd_, SHUT_WR);
shutdownFlags_ |= SHUT_WRITE;
}
VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
<< "state=" << state_ << ", events=" << std::hex << eventFlags_;
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (state_ == StateEnum::CONNECTING) {
// connect() timed out
// Unregister for I/O events.
- AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
- "connect timed out");
- failConnect(__func__, ex);
+ if (connectCallback_) {
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT,
+ folly::sformat(
+ "connect timed out after {}ms", connectTimeout_.count()));
+ failConnect(__func__, ex);
+ } else {
+ // we faced a connect error without a connect callback, which could
+ // happen due to TFO.
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT, "write timed out during connection");
+ failWrite(__func__, ex);
+ }
} else {
// a normal write operation timed out
- assert(state_ == StateEnum::ESTABLISHED);
- AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT,
+ folly::sformat("write timed out after {}ms", sendTimeout_));
failWrite(__func__, ex);
}
}
-ssize_t AsyncSocket::performWrite(const iovec* vec,
- uint32_t count,
- WriteFlags flags,
- uint32_t* countWritten,
- uint32_t* partialWritten) {
+ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
+ return detail::tfo_sendmsg(fd, msg, msg_flags);
+}
+
+AsyncSocket::WriteResult
+AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
+ ssize_t totalWritten = 0;
+ if (state_ == StateEnum::FAST_OPEN) {
+ sockaddr_storage addr;
+ auto len = addr_.getAddress(&addr);
+ msg->msg_name = &addr;
+ msg->msg_namelen = len;
+ totalWritten = tfoSendMsg(fd_, msg, msg_flags);
+ if (totalWritten >= 0) {
+ tfoFinished_ = true;
+ state_ = StateEnum::ESTABLISHED;
+ // We schedule this asynchrously so that we don't end up
+ // invoking initial read or write while a write is in progress.
+ scheduleInitialReadWrite();
+ } else if (errno == EINPROGRESS) {
+ VLOG(4) << "TFO falling back to connecting";
+ // A normal sendmsg doesn't return EINPROGRESS, however
+ // TFO might fallback to connecting if there is no
+ // cookie.
+ state_ = StateEnum::CONNECTING;
+ try {
+ scheduleConnectTimeout();
+ registerForConnectEvents();
+ } catch (const AsyncSocketException& ex) {
+ return WriteResult(
+ WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
+ }
+ // Let's fake it that no bytes were written and return an errno.
+ errno = EAGAIN;
+ totalWritten = -1;
+ } else if (errno == EOPNOTSUPP) {
+ // Try falling back to connecting.
+ VLOG(4) << "TFO not supported";
+ state_ = StateEnum::CONNECTING;
+ try {
+ int ret = socketConnect((const sockaddr*)&addr, len);
+ if (ret == 0) {
+ // connect succeeded immediately
+ // Treat this like no data was written.
+ state_ = StateEnum::ESTABLISHED;
+ scheduleInitialReadWrite();
+ }
+ // If there was no exception during connections,
+ // we would return that no bytes were written.
+ errno = EAGAIN;
+ totalWritten = -1;
+ } catch (const AsyncSocketException& ex) {
+ return WriteResult(
+ WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
+ }
+ } else if (errno == EAGAIN) {
+ // Normally sendmsg would indicate that the write would block.
+ // However in the fast open case, it would indicate that sendmsg
+ // fell back to a connect. This is a return code from connect()
+ // instead, and is an error condition indicating no fds available.
+ return WriteResult(
+ WRITE_ERROR,
+ std::make_unique<AsyncSocketException>(
+ AsyncSocketException::UNKNOWN, "No more free local ports"));
+ }
+ } else {
+ totalWritten = ::sendmsg(fd, msg, msg_flags);
+ }
+ return WriteResult(totalWritten);
+}
+
+AsyncSocket::WriteResult AsyncSocket::performWrite(
+ const iovec* vec,
+ uint32_t count,
+ WriteFlags flags,
+ uint32_t* countWritten,
+ uint32_t* partialWritten) {
// We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
// We correctly handle EPIPE errors, so we never want to receive SIGPIPE
// (since it may terminate the program if the main program doesn't explicitly
msg.msg_namelen = 0;
msg.msg_iov = const_cast<iovec *>(vec);
msg.msg_iovlen = std::min<size_t>(count, kIovMax);
- msg.msg_control = nullptr;
- msg.msg_controllen = 0;
msg.msg_flags = 0;
+ msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
+ CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+ msg.msg_controllen);
- int msg_flags = MSG_DONTWAIT;
-
-#ifdef MSG_NOSIGNAL // Linux-only
- msg_flags |= MSG_NOSIGNAL;
- if (isSet(flags, WriteFlags::CORK)) {
- // MSG_MORE tells the kernel we have more data to send, so wait for us to
- // give it the rest of the data rather than immediately sending a partial
- // frame, even when TCP_NODELAY is enabled.
- msg_flags |= MSG_MORE;
- }
-#endif
- if (isSet(flags, WriteFlags::EOR)) {
- // marks that this is the last byte of a record (response)
- msg_flags |= MSG_EOR;
+ if (msg.msg_controllen != 0) {
+ msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+ sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
+ } else {
+ msg.msg_control = nullptr;
}
- ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags);
+ int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
+
+ auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
+ auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0) {
- if (errno == EAGAIN) {
+ bool tryAgain = (errno == EAGAIN);
+#ifdef __APPLE__
+ // Apple has a bug where doing a second write on a socket which we
+ // have opened with TFO causes an ENOTCONN to be thrown. However the
+ // socket is really connected, so treat ENOTCONN as a EAGAIN until
+ // this bug is fixed.
+ tryAgain |= (errno == ENOTCONN);
+#endif
+ if (!writeResult.exception && tryAgain) {
// TCP buffer is full; we can't write any more data right now.
*countWritten = 0;
*partialWritten = 0;
- return 0;
+ return WriteResult(0);
}
// error
*countWritten = 0;
*partialWritten = 0;
- return -1;
+ return writeResult;
}
appBytesWritten_ += totalWritten;
uint32_t bytesWritten;
uint32_t n;
- for (bytesWritten = totalWritten, n = 0; n < count; ++n) {
+ for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
const iovec* v = vec + n;
if (v->iov_len > bytesWritten) {
// Partial write finished in the middle of this iovec
*countWritten = n;
*partialWritten = bytesWritten;
- return totalWritten;
+ return WriteResult(totalWritten);
}
- bytesWritten -= v->iov_len;
+ bytesWritten -= uint32_t(v->iov_len);
}
assert(bytesWritten == 0);
*countWritten = n;
*partialWritten = 0;
- return totalWritten;
+ return WriteResult(totalWritten);
}
/**
* and call all currently installed callbacks. After an error, the
* AsyncSocket is completely unregistered.
*
- * @return Returns true on succcess, or false on error.
+ * @return Returns true on success, or false on error.
*/
bool AsyncSocket::updateEventRegistration() {
VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
<< ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
<< ", events=" << std::hex << eventFlags_;
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
if (eventFlags_ == EventHandler::NONE) {
ioHandler_.unregisterHandler();
return true;
// Always register for persistent events, so we don't have to re-register
// after being called back.
- if (!ioHandler_.registerHandler(eventFlags_ | EventHandler::PERSIST)) {
+ if (!ioHandler_.registerHandler(
+ uint16_t(eventFlags_ | EventHandler::PERSIST))) {
eventFlags_ = EventHandler::NONE; // we're not registered after error
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
withAddr("failed to update AsyncSocket event registration"));
}
}
-void AsyncSocket::finishFail() {
- assert(state_ == StateEnum::ERROR);
- assert(getDestructorGuardCount() > 0);
-
- AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
- withAddr("socket closing after error"));
+void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
invokeConnectErr(ex);
failAllWrites(ex);
}
}
+void AsyncSocket::finishFail() {
+ assert(state_ == StateEnum::ERROR);
+ assert(getDestructorGuardCount() > 0);
+
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("socket closing after error"));
+ invokeAllErrors(ex);
+}
+
+void AsyncSocket::finishFail(const AsyncSocketException& ex) {
+ assert(state_ == StateEnum::ERROR);
+ assert(getDestructorGuardCount() > 0);
+ invokeAllErrors(ex);
+}
+
void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
<< state_ << " host=" << addr_.describe()
startFail();
invokeConnectErr(ex);
- finishFail();
+ finishFail(ex);
}
void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
finishFail();
}
+void AsyncSocket::failErrMessageRead(const char* fn,
+ const AsyncSocketException& ex) {
+ VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
+ << state_ << " host=" << addr_.describe()
+ << "): failed while reading message in " << fn << "(): "
+ << ex.what();
+ startFail();
+
+ if (errMessageCallback_ != nullptr) {
+ ErrMessageCallback* callback = errMessageCallback_;
+ errMessageCallback_ = nullptr;
+ callback->errMessageError(ex);
+ }
+
+ finishFail();
+}
+
void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
<< state_ << " host=" << addr_.describe()
void AsyncSocket::invalidState(ConnectCallback* callback) {
VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
- << "): connect() called in invalid state " << state_;
+ << "): connect() called in invalid state " << state_;
/*
* The invalidState() methods don't use the normal failure mechanisms,
}
}
+void AsyncSocket::invalidState(ErrMessageCallback* callback) {
+ VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
+ << "): setErrMessageCB(" << callback
+ << ") called in invalid state " << state_;
+
+ AsyncSocketException ex(
+ AsyncSocketException::NOT_OPEN,
+ msgErrQueueSupported
+ ? "setErrMessageCB() called with socket in invalid state"
+ : "This platform does not support socket error message notifications");
+ if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
+ if (callback) {
+ callback->errMessageError(ex);
+ }
+ } else {
+ startFail();
+ if (callback) {
+ callback->errMessageError(ex);
+ }
+ finishFail();
+ }
+}
+
void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
connectEndTime_ = std::chrono::steady_clock::now();
if (connectCallback_) {
bufferCallback_ = cb;
}
-} // folly
+} // namespace folly