#include <folly/portability/SysUio.h>
#include <folly/portability/Unistd.h>
+#include <boost/preprocessor/control/if.hpp>
#include <errno.h>
#include <limits.h>
-#include <thread>
#include <sys/types.h>
-#include <boost/preprocessor/control/if.hpp>
+#include <thread>
using std::string;
using std::unique_ptr;
namespace folly {
+static constexpr bool msgErrQueueSupported =
+#ifdef MSG_ERRQUEUE
+ true;
+#else
+ false;
+#endif // MSG_ERRQUEUE
+
// static members initializers
const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
sendTimeout_ = 0;
maxReadsPerEvent_ = 16;
connectCallback_ = nullptr;
+ errMessageCallback_ = nullptr;
readCallback_ = nullptr;
writeReqHead_ = nullptr;
writeReqTail_ = nullptr;
// The read callback may not have been set yet, and no writes may be pending
// yet, so we don't have to register for any events at the moment.
VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
+ assert(errMessageCallback_ == nullptr);
assert(readCallback_ == nullptr);
assert(writeReqHead_ == nullptr);
if (state_ != StateEnum::FAST_OPEN) {
}
}
+void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
+ VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this
+ << ", fd=" << fd_ << ", callback=" << callback
+ << ", state=" << state_;
+
+ // Short circuit if callback is the same as the existing timestampCallback_.
+ if (callback == errMessageCallback_) {
+ return;
+ }
+
+ if (!msgErrQueueSupported) {
+ // Per-socket error message queue is not supported on this platform.
+ return invalidState(callback);
+ }
+
+ DestructorGuard dg(this);
+ assert(eventBase_->isInEventBaseThread());
+
+ switch ((StateEnum)state_) {
+ case StateEnum::CONNECTING:
+ case StateEnum::FAST_OPEN:
+ case StateEnum::ESTABLISHED: {
+ errMessageCallback_ = callback;
+ return;
+ }
+ case StateEnum::CLOSED:
+ case StateEnum::ERROR:
+ // We should never reach here. SHUT_READ should always be set
+ // if we are in STATE_CLOSED or STATE_ERROR.
+ assert(false);
+ return invalidState(callback);
+ case StateEnum::UNINIT:
+ // We do not allow setReadCallback() to be called before we start
+ // connecting.
+ return invalidState(callback);
+ }
+
+ // We don't put a default case in the switch statement, so that the compiler
+ // will warn us to update the switch statement if a new state is added.
+ return invalidState(callback);
+}
+
+AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
+ return errMessageCallback_;
+}
+
void AsyncSocket::setReadCB(ReadCallback *callback) {
VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", state=" << state_;
assert(eventBase_->isInEventBaseThread());
uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
+ EventBase* originalEventBase = eventBase_;
+ // If we got there it means that either EventHandler::READ or
+ // EventHandler::WRITE is set. Any of these flags can
+ // indicate that there are messages available in the socket
+ // error message queue.
+ handleErrMessages();
+
+ // Return now if handleErrMessages() detached us from our EventBase
+ if (eventBase_ != originalEventBase) {
+ return;
+ }
+
if (relevantEvents == EventHandler::READ) {
handleRead();
} else if (relevantEvents == EventHandler::WRITE) {
handleWrite();
} else if (relevantEvents == EventHandler::READ_WRITE) {
- EventBase* originalEventBase = eventBase_;
// If both read and write events are ready, process writes first.
handleWrite();
readCallback_->getReadBuffer(buf, buflen);
}
+void AsyncSocket::handleErrMessages() noexcept {
+ // This method has non-empty implementation only for platforms
+ // supporting per-socket error queues.
+ VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
+ << ", state=" << state_;
+ if (errMessageCallback_ == nullptr) {
+ VLOG(7) << "AsyncSocket::handleErrMessages(): "
+ << "no callback installed - exiting.";
+ return;
+ }
+
+#ifdef MSG_ERRQUEUE
+ uint8_t ctrl[1024];
+ unsigned char data;
+ struct msghdr msg;
+ iovec entry;
+
+ entry.iov_base = &data;
+ entry.iov_len = sizeof(data);
+ msg.msg_iov = &entry;
+ msg.msg_iovlen = 1;
+ msg.msg_name = nullptr;
+ msg.msg_namelen = 0;
+ msg.msg_control = ctrl;
+ msg.msg_controllen = sizeof(ctrl);
+ msg.msg_flags = 0;
+
+ int ret;
+ while (true) {
+ ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
+ VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
+
+ if (ret < 0) {
+ if (errno != EAGAIN) {
+ auto errnoCopy = errno;
+ LOG(ERROR) << "::recvmsg exited with code " << ret
+ << ", errno: " << errnoCopy;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("recvmsg() failed"),
+ errnoCopy);
+ failErrMessageRead(__func__, ex);
+ }
+ return;
+ }
+
+ for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg != nullptr && cmsg->cmsg_len != 0;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ errMessageCallback_->errMessage(*cmsg);
+ }
+ }
+#endif //MSG_ERRQUEUE
+}
+
void AsyncSocket::handleRead() noexcept {
VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
<< ", state=" << state_;
finishFail();
}
+void AsyncSocket::failErrMessageRead(const char* fn,
+ const AsyncSocketException& ex) {
+ VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
+ << state_ << " host=" << addr_.describe()
+ << "): failed while reading message in " << fn << "(): "
+ << ex.what();
+ startFail();
+
+ if (errMessageCallback_ != nullptr) {
+ ErrMessageCallback* callback = errMessageCallback_;
+ errMessageCallback_ = nullptr;
+ callback->errMessageError(ex);
+ }
+
+ finishFail();
+}
+
void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
<< state_ << " host=" << addr_.describe()
void AsyncSocket::invalidState(ConnectCallback* callback) {
VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
- << "): connect() called in invalid state " << state_;
+ << "): connect() called in invalid state " << state_;
/*
* The invalidState() methods don't use the normal failure mechanisms,
}
}
+void AsyncSocket::invalidState(ErrMessageCallback* callback) {
+ VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
+ << "): setErrMessageCB(" << callback
+ << ") called in invalid state " << state_;
+
+ AsyncSocketException ex(
+ AsyncSocketException::NOT_OPEN,
+ msgErrQueueSupported
+ ? "setErrMessageCB() called with socket in invalid state"
+ : "This platform does not support socket error message notifications");
+ if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
+ if (callback) {
+ callback->errMessageError(ex);
+ }
+ } else {
+ startFail();
+ if (callback) {
+ callback->errMessageError(ex);
+ }
+ finishFail();
+ }
+}
+
void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
connectEndTime_ = std::chrono::steady_clock::now();
if (connectCallback_) {
virtual void evbDetached(AsyncSocket* socket) = 0;
};
+ /**
+ * This interface is implemented only for platforms supporting
+ * per-socket error queues.
+ */
+ class ErrMessageCallback {
+ public:
+ virtual ~ErrMessageCallback() = default;
+
+ /**
+ * errMessage() will be invoked when kernel puts a message to
+ * the error queue associated with the socket.
+ *
+ * @param cmsg Reference to cmsghdr structure describing
+ * a message read from error queue associated
+ * with the socket.
+ */
+ virtual void
+ errMessage(const cmsghdr& cmsg) noexcept = 0;
+
+ /**
+ * errMessageError() will be invoked if an error occurs reading a message
+ * from the socket error stream.
+ *
+ * @param ex An exception describing the error that occurred.
+ */
+ virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
+ };
+
explicit AsyncSocket();
/**
* Create a new unconnected AsyncSocket.
return maxReadsPerEvent_;
}
+ /**
+ * Set a pointer to ErrMessageCallback implementation which will be
+ * receiving notifications for messages posted to the error queue
+ * associated with the socket.
+ * ErrMessageCallback is implemented only for platforms with
+ * per-socket error message queus support (recvmsg() system call must
+ * )
+ *
+ */
+ void setErrMessageCB(ErrMessageCallback* callback);
+
+ /**
+ * Get a pointer to ErrMessageCallback implementation currently
+ * registered with this socket.
+ *
+ */
+ ErrMessageCallback* getErrMessageCallback() const;
+
// Read and write methods
void setReadCB(ReadCallback* callback) override;
ReadCallback* getReadCallback() const override;
virtual void checkForImmediateRead() noexcept;
virtual void handleInitialReadWrite() noexcept;
virtual void prepareReadBuffer(void** buf, size_t* buflen);
+ virtual void handleErrMessages() noexcept;
virtual void handleRead() noexcept;
virtual void handleWrite() noexcept;
virtual void handleConnect() noexcept;
void fail(const char* fn, const AsyncSocketException& ex);
void failConnect(const char* fn, const AsyncSocketException& ex);
void failRead(const char* fn, const AsyncSocketException& ex);
+ void failErrMessageRead(const char* fn, const AsyncSocketException& ex);
void failWrite(const char* fn, WriteCallback* callback, size_t bytesWritten,
const AsyncSocketException& ex);
void failWrite(const char* fn, const AsyncSocketException& ex);
virtual void invokeConnectErr(const AsyncSocketException& ex);
virtual void invokeConnectSuccess();
void invalidState(ConnectCallback* callback);
+ void invalidState(ErrMessageCallback* callback);
void invalidState(ReadCallback* callback);
void invalidState(WriteCallback* callback);
std::string withAddr(const std::string& s);
- StateEnum state_; ///< StateEnum describing current state
- uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags)
- uint16_t eventFlags_; ///< EventBase::HandlerFlags settings
- int fd_; ///< The socket file descriptor
+ StateEnum state_; ///< StateEnum describing current state
+ uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags)
+ uint16_t eventFlags_; ///< EventBase::HandlerFlags settings
+ int fd_; ///< The socket file descriptor
mutable folly::SocketAddress addr_; ///< The address we tried to connect to
mutable folly::SocketAddress localAddr_;
- ///< The address we are connecting from
- uint32_t sendTimeout_; ///< The send timeout, in milliseconds
- uint16_t maxReadsPerEvent_; ///< Max reads per event loop iteration
- EventBase* eventBase_; ///< The EventBase
- WriteTimeout writeTimeout_; ///< A timeout for connect and write
- IoHandler ioHandler_; ///< A EventHandler to monitor the fd
+ ///< The address we are connecting from
+ uint32_t sendTimeout_; ///< The send timeout, in milliseconds
+ uint16_t maxReadsPerEvent_; ///< Max reads per event loop iteration
+ EventBase* eventBase_; ///< The EventBase
+ WriteTimeout writeTimeout_; ///< A timeout for connect and write
+ IoHandler ioHandler_; ///< A EventHandler to monitor the fd
ImmediateReadCB immediateReadHandler_; ///< LoopCallback for checking read
- ConnectCallback* connectCallback_; ///< ConnectCallback
- ReadCallback* readCallback_; ///< ReadCallback
- WriteRequest* writeReqHead_; ///< Chain of WriteRequests
- WriteRequest* writeReqTail_; ///< End of WriteRequest chain
+ ConnectCallback* connectCallback_; ///< ConnectCallback
+ ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
+ ReadCallback* readCallback_; ///< ReadCallback
+ WriteRequest* writeReqHead_; ///< Chain of WriteRequests
+ WriteRequest* writeReqTail_; ///< End of WriteRequest chain
ShutdownSocketSet* shutdownSocketSet_;
- size_t appBytesReceived_; ///< Num of bytes received from socket
- size_t appBytesWritten_; ///< Num of bytes written to socket
+ size_t appBytesReceived_; ///< Num of bytes received from socket
+ size_t appBytesWritten_; ///< Num of bytes written to socket
bool isBufferMovable_{false};
bool peek_{false}; // Peek bytes.
- int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
+ int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
std::chrono::steady_clock::time_point connectStartTime_;
std::chrono::steady_clock::time_point connectEndTime_;
class ReadVerifier {
};
+class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
+ public:
+ TestErrMessageCallback()
+ : exception_(folly::AsyncSocketException::UNKNOWN, "none")
+ {}
+
+ void errMessage(const cmsghdr& cmsg) noexcept override {
+ if (cmsg.cmsg_level == SOL_SOCKET &&
+ cmsg.cmsg_type == SCM_TIMESTAMPING) {
+ gotTimestamp_ = true;
+ } else if (
+ (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+ (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
+ gotByteSeq_ = true;
+ }
+ }
+
+ void errMessageError(
+ const folly::AsyncSocketException& ex) noexcept override {
+ exception_ = ex;
+ }
+
+ folly::AsyncSocketException exception_;
+ bool gotTimestamp_{false};
+ bool gotByteSeq_{false};
+};
+
class TestServer {
public:
// Create a TestServer.
socket->attachEventBase(&evb);
}
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+ // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
+ // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+ // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
+ // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
+ SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+ // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
+ // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
+ SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+ SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+ // SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+ SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
+ SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+
+ // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
+ // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
+};
+TEST(AsyncSocketTest, ErrMessageCallback) {
+ TestServer server;
+
+ // connect()
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+ LOG(INFO) << "Client socket fd=" << socket->getFd();
+
+ // Let the socket
+ evb.loop();
+
+ ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
+
+ // Set read callback to keep the socket subscribed for event
+ // notifications. Though we're no planning to read anything from
+ // this side of the connection.
+ ReadCallback rcb(1);
+ socket->setReadCB(&rcb);
+
+ // Set up timestamp callbacks
+ TestErrMessageCallback errMsgCB;
+ socket->setErrMessageCB(&errMsgCB);
+ ASSERT_EQ(socket->getErrMessageCallback(),
+ static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
+
+ // Enable timestamp notifications
+ ASSERT_GT(socket->getFd(), 0);
+ int flags = SOF_TIMESTAMPING_OPT_ID
+ | SOF_TIMESTAMPING_OPT_TSONLY
+ | SOF_TIMESTAMPING_SOFTWARE
+ | SOF_TIMESTAMPING_OPT_CMSG
+ | SOF_TIMESTAMPING_TX_SCHED;
+ AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+ EXPECT_EQ(tstampingOpt.apply(socket->getFd(), flags), 0);
+
+ // write()
+ std::vector<uint8_t> wbuf(128, 'a');
+ WriteCallback wcb;
+ socket->write(&wcb, wbuf.data(), wbuf.size());
+
+ // Accept the connection.
+ std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+ LOG(INFO) << "Server socket fd=" << acceptedSocket->getSocketFD();
+
+ // Loop
+ evb.loopOnce();
+ ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+
+ // Check that we can read the data that was written to the socket
+ std::vector<uint8_t> rbuf(1 + wbuf.size(), 0);
+ uint32_t bytesRead = acceptedSocket->read(rbuf.data(), rbuf.size());
+ ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));
+ ASSERT_EQ(bytesRead, wbuf.size());
+
+ // Close both sockets
+ acceptedSocket->close();
+ socket->close();
+
+ ASSERT_TRUE(socket->isClosedBySelf());
+ ASSERT_FALSE(socket->isClosedByPeer());
+
+ // Check for the timestamp notifications.
+ ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN);
+ ASSERT_TRUE(errMsgCB.gotByteSeq_);
+ ASSERT_TRUE(errMsgCB.gotTimestamp_);
+}
+#endif // MSG_ERRQUEUE
+
#endif