Timestamping callback interface in folly::AsyncSocket
authorMaxim Georgiev <maxgeorg@fb.com>
Tue, 14 Feb 2017 01:32:45 +0000 (17:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Feb 2017 01:34:41 +0000 (17:34 -0800)
Summary: Adding an interface to folly::AsyncSocket allowing to receive timestamp notifications collected using recvmsg(fd, MSG_ERRQUEUE, ...).

Reviewed By: djwatson

Differential Revision: D4329066

fbshipit-source-id: 154f5c0d04e5c0e410081d48c937af4069479fc2

folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSocketTest.h
folly/io/async/test/AsyncSocketTest2.cpp

index 8a9a37e72a80fc03ecae6cf66bc3074bbb2dd38a..d88f9d851deff798f25930e09ff2d7ecd6a44650 100644 (file)
 #include <folly/portability/SysUio.h>
 #include <folly/portability/Unistd.h>
 
+#include <boost/preprocessor/control/if.hpp>
 #include <errno.h>
 #include <limits.h>
-#include <thread>
 #include <sys/types.h>
-#include <boost/preprocessor/control/if.hpp>
+#include <thread>
 
 using std::string;
 using std::unique_ptr;
@@ -38,6 +38,13 @@ namespace fsp = folly::portability::sockets;
 
 namespace folly {
 
+static constexpr bool msgErrQueueSupported =
+#ifdef MSG_ERRQUEUE
+    true;
+#else
+    false;
+#endif // MSG_ERRQUEUE
+
 // static members initializers
 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
 
@@ -233,6 +240,7 @@ void AsyncSocket::init() {
   sendTimeout_ = 0;
   maxReadsPerEvent_ = 16;
   connectCallback_ = nullptr;
+  errMessageCallback_ = nullptr;
   readCallback_ = nullptr;
   writeReqHead_ = nullptr;
   writeReqTail_ = nullptr;
@@ -462,6 +470,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
   // The read callback may not have been set yet, and no writes may be pending
   // yet, so we don't have to register for any events at the moment.
   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
+  assert(errMessageCallback_ == nullptr);
   assert(readCallback_ == nullptr);
   assert(writeReqHead_ == nullptr);
   if (state_ != StateEnum::FAST_OPEN) {
@@ -563,6 +572,52 @@ void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
   }
 }
 
+void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
+  VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this
+          << ", fd=" << fd_ << ", callback=" << callback
+          << ", state=" << state_;
+
+  // Short circuit if callback is the same as the existing timestampCallback_.
+  if (callback == errMessageCallback_) {
+    return;
+  }
+
+  if (!msgErrQueueSupported) {
+      // Per-socket error message queue is not supported on this platform.
+      return invalidState(callback);
+  }
+
+  DestructorGuard dg(this);
+  assert(eventBase_->isInEventBaseThread());
+
+  switch ((StateEnum)state_) {
+    case StateEnum::CONNECTING:
+    case StateEnum::FAST_OPEN:
+    case StateEnum::ESTABLISHED: {
+      errMessageCallback_ = callback;
+      return;
+    }
+    case StateEnum::CLOSED:
+    case StateEnum::ERROR:
+      // We should never reach here.  SHUT_READ should always be set
+      // if we are in STATE_CLOSED or STATE_ERROR.
+      assert(false);
+      return invalidState(callback);
+    case StateEnum::UNINIT:
+      // We do not allow setReadCallback() to be called before we start
+      // connecting.
+      return invalidState(callback);
+  }
+
+  // We don't put a default case in the switch statement, so that the compiler
+  // will warn us to update the switch statement if a new state is added.
+  return invalidState(callback);
+}
+
+AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
+  return errMessageCallback_;
+}
+
 void AsyncSocket::setReadCB(ReadCallback *callback) {
   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
           << ", callback=" << callback << ", state=" << state_;
@@ -1307,12 +1362,23 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
   assert(eventBase_->isInEventBaseThread());
 
   uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
+  EventBase* originalEventBase = eventBase_;
+  // If we got there it means that either EventHandler::READ or
+  // EventHandler::WRITE is set. Any of these flags can
+  // indicate that there are messages available in the socket
+  // error message queue.
+  handleErrMessages();
+
+  // Return now if handleErrMessages() detached us from our EventBase
+  if (eventBase_ != originalEventBase) {
+    return;
+  }
+
   if (relevantEvents == EventHandler::READ) {
     handleRead();
   } else if (relevantEvents == EventHandler::WRITE) {
     handleWrite();
   } else if (relevantEvents == EventHandler::READ_WRITE) {
-    EventBase* originalEventBase = eventBase_;
     // If both read and write events are ready, process writes first.
     handleWrite();
 
@@ -1364,6 +1430,61 @@ void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
   readCallback_->getReadBuffer(buf, buflen);
 }
 
+void AsyncSocket::handleErrMessages() noexcept {
+  // This method has non-empty implementation only for platforms
+  // supporting per-socket error queues.
+  VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
+          << ", state=" << state_;
+  if (errMessageCallback_ == nullptr) {
+    VLOG(7) << "AsyncSocket::handleErrMessages(): "
+            << "no callback installed - exiting.";
+    return;
+  }
+
+#ifdef MSG_ERRQUEUE
+  uint8_t ctrl[1024];
+  unsigned char data;
+  struct msghdr msg;
+  iovec entry;
+
+  entry.iov_base = &data;
+  entry.iov_len = sizeof(data);
+  msg.msg_iov = &entry;
+  msg.msg_iovlen = 1;
+  msg.msg_name = nullptr;
+  msg.msg_namelen = 0;
+  msg.msg_control = ctrl;
+  msg.msg_controllen = sizeof(ctrl);
+  msg.msg_flags = 0;
+
+  int ret;
+  while (true) {
+    ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
+    VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
+
+    if (ret < 0) {
+      if (errno != EAGAIN) {
+        auto errnoCopy = errno;
+        LOG(ERROR) << "::recvmsg exited with code " << ret
+                   << ", errno: " << errnoCopy;
+        AsyncSocketException ex(
+          AsyncSocketException::INTERNAL_ERROR,
+          withAddr("recvmsg() failed"),
+          errnoCopy);
+        failErrMessageRead(__func__, ex);
+      }
+      return;
+    }
+
+    for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+         cmsg != nullptr && cmsg->cmsg_len != 0;
+         cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+      errMessageCallback_->errMessage(*cmsg);
+    }
+  }
+#endif //MSG_ERRQUEUE
+}
+
 void AsyncSocket::handleRead() noexcept {
   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
           << ", state=" << state_;
@@ -2070,6 +2191,23 @@ void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
   finishFail();
 }
 
+void AsyncSocket::failErrMessageRead(const char* fn,
+                                     const AsyncSocketException& ex) {
+  VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
+               << state_ << " host=" << addr_.describe()
+               << "): failed while reading message in " << fn << "(): "
+               << ex.what();
+  startFail();
+
+  if (errMessageCallback_ != nullptr) {
+    ErrMessageCallback* callback = errMessageCallback_;
+    errMessageCallback_ = nullptr;
+    callback->errMessageError(ex);
+  }
+
+  finishFail();
+}
+
 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
                << state_ << " host=" << addr_.describe()
@@ -2129,7 +2267,7 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
 
 void AsyncSocket::invalidState(ConnectCallback* callback) {
   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
-             << "): connect() called in invalid state " << state_;
+          << "): connect() called in invalid state " << state_;
 
   /*
    * The invalidState() methods don't use the normal failure mechanisms,
@@ -2157,6 +2295,29 @@ void AsyncSocket::invalidState(ConnectCallback* callback) {
   }
 }
 
+void AsyncSocket::invalidState(ErrMessageCallback* callback) {
+  VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
+          << "): setErrMessageCB(" << callback
+          << ") called in invalid state " << state_;
+
+  AsyncSocketException ex(
+      AsyncSocketException::NOT_OPEN,
+      msgErrQueueSupported
+      ? "setErrMessageCB() called with socket in invalid state"
+      : "This platform does not support socket error message notifications");
+  if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
+    if (callback) {
+      callback->errMessageError(ex);
+    }
+  } else {
+    startFail();
+    if (callback) {
+      callback->errMessageError(ex);
+    }
+    finishFail();
+  }
+}
+
 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
   connectEndTime_ = std::chrono::steady_clock::now();
   if (connectCallback_) {
index 5aeb159c65fb635b8bc872688560ed6712348832..54b37705ae9b5879107485756fea1ed7df5f02e7 100644 (file)
@@ -111,6 +111,34 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     virtual void evbDetached(AsyncSocket* socket) = 0;
   };
 
+  /**
+   * This interface is implemented only for platforms supporting
+   * per-socket error queues.
+   */
+  class ErrMessageCallback {
+   public:
+    virtual ~ErrMessageCallback() = default;
+
+    /**
+     * errMessage() will be invoked when kernel puts a message to
+     * the error queue associated with the socket.
+     *
+     * @param cmsg      Reference to cmsghdr structure describing
+     *                  a message read from error queue associated
+     *                  with the socket.
+     */
+    virtual void
+    errMessage(const cmsghdr& cmsg) noexcept = 0;
+
+    /**
+     * errMessageError() will be invoked if an error occurs reading a message
+     * from the socket error stream.
+     *
+     * @param ex        An exception describing the error that occurred.
+     */
+    virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
+  };
+
   explicit AsyncSocket();
   /**
    * Create a new unconnected AsyncSocket.
@@ -353,6 +381,24 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return maxReadsPerEvent_;
   }
 
+  /**
+   * Set a pointer to ErrMessageCallback implementation which will be
+   * receiving notifications for messages posted to the error queue
+   * associated with the socket.
+   * ErrMessageCallback is implemented only for platforms with
+   * per-socket error message queus support (recvmsg() system call must
+   * )
+   *
+   */
+  void setErrMessageCB(ErrMessageCallback* callback);
+
+  /**
+   * Get a pointer to ErrMessageCallback implementation currently
+   * registered with this socket.
+   *
+   */
+  ErrMessageCallback* getErrMessageCallback() const;
+
   // Read and write methods
   void setReadCB(ReadCallback* callback) override;
   ReadCallback* getReadCallback() const override;
@@ -799,6 +845,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   virtual void checkForImmediateRead() noexcept;
   virtual void handleInitialReadWrite() noexcept;
   virtual void prepareReadBuffer(void** buf, size_t* buflen);
+  virtual void handleErrMessages() noexcept;
   virtual void handleRead() noexcept;
   virtual void handleWrite() noexcept;
   virtual void handleConnect() noexcept;
@@ -913,6 +960,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void fail(const char* fn, const AsyncSocketException& ex);
   void failConnect(const char* fn, const AsyncSocketException& ex);
   void failRead(const char* fn, const AsyncSocketException& ex);
+  void failErrMessageRead(const char* fn, const AsyncSocketException& ex);
   void failWrite(const char* fn, WriteCallback* callback, size_t bytesWritten,
                  const AsyncSocketException& ex);
   void failWrite(const char* fn, const AsyncSocketException& ex);
@@ -920,37 +968,39 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   virtual void invokeConnectErr(const AsyncSocketException& ex);
   virtual void invokeConnectSuccess();
   void invalidState(ConnectCallback* callback);
+  void invalidState(ErrMessageCallback* callback);
   void invalidState(ReadCallback* callback);
   void invalidState(WriteCallback* callback);
 
   std::string withAddr(const std::string& s);
 
-  StateEnum state_;                     ///< StateEnum describing current state
-  uint8_t shutdownFlags_;               ///< Shutdown state (ShutdownFlags)
-  uint16_t eventFlags_;                 ///< EventBase::HandlerFlags settings
-  int fd_;                              ///< The socket file descriptor
+  StateEnum state_;                      ///< StateEnum describing current state
+  uint8_t shutdownFlags_;                ///< Shutdown state (ShutdownFlags)
+  uint16_t eventFlags_;                  ///< EventBase::HandlerFlags settings
+  int fd_;                               ///< The socket file descriptor
   mutable folly::SocketAddress addr_;    ///< The address we tried to connect to
   mutable folly::SocketAddress localAddr_;
-                                        ///< The address we are connecting from
-  uint32_t sendTimeout_;                ///< The send timeout, in milliseconds
-  uint16_t maxReadsPerEvent_;           ///< Max reads per event loop iteration
-  EventBase* eventBase_;                ///< The EventBase
-  WriteTimeout writeTimeout_;           ///< A timeout for connect and write
-  IoHandler ioHandler_;                 ///< A EventHandler to monitor the fd
+                                         ///< The address we are connecting from
+  uint32_t sendTimeout_;                 ///< The send timeout, in milliseconds
+  uint16_t maxReadsPerEvent_;            ///< Max reads per event loop iteration
+  EventBase* eventBase_;                 ///< The EventBase
+  WriteTimeout writeTimeout_;            ///< A timeout for connect and write
+  IoHandler ioHandler_;                  ///< A EventHandler to monitor the fd
   ImmediateReadCB immediateReadHandler_; ///< LoopCallback for checking read
 
-  ConnectCallback* connectCallback_;    ///< ConnectCallback
-  ReadCallback* readCallback_;          ///< ReadCallback
-  WriteRequest* writeReqHead_;          ///< Chain of WriteRequests
-  WriteRequest* writeReqTail_;          ///< End of WriteRequest chain
+  ConnectCallback* connectCallback_;     ///< ConnectCallback
+  ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
+  ReadCallback* readCallback_;           ///< ReadCallback
+  WriteRequest* writeReqHead_;           ///< Chain of WriteRequests
+  WriteRequest* writeReqTail_;           ///< End of WriteRequest chain
   ShutdownSocketSet* shutdownSocketSet_;
-  size_t appBytesReceived_;             ///< Num of bytes received from socket
-  size_t appBytesWritten_;              ///< Num of bytes written to socket
+  size_t appBytesReceived_;              ///< Num of bytes received from socket
+  size_t appBytesWritten_;               ///< Num of bytes written to socket
   bool isBufferMovable_{false};
 
   bool peek_{false}; // Peek bytes.
 
-  int8_t readErr_{READ_NO_ERROR};      ///< The read error encountered, if any.
+  int8_t readErr_{READ_NO_ERROR};       ///< The read error encountered, if any.
 
   std::chrono::steady_clock::time_point connectStartTime_;
   std::chrono::steady_clock::time_point connectEndTime_;
index 8b9bf95fb52327ffc6488e1bfd23df9bdf18db59..d69c851b7fe1f74ba080df5eb0958404f3192f30 100644 (file)
@@ -202,6 +202,33 @@ class BufferCallback : public folly::AsyncTransport::BufferCallback {
 class ReadVerifier {
 };
 
+class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
+ public:
+  TestErrMessageCallback()
+    : exception_(folly::AsyncSocketException::UNKNOWN, "none")
+  {}
+
+  void errMessage(const cmsghdr& cmsg) noexcept override {
+    if (cmsg.cmsg_level == SOL_SOCKET &&
+      cmsg.cmsg_type == SCM_TIMESTAMPING) {
+      gotTimestamp_ = true;
+    } else if (
+      (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+      (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
+      gotByteSeq_ = true;
+    }
+  }
+
+  void errMessageError(
+      const folly::AsyncSocketException& ex) noexcept override {
+    exception_ = ex;
+  }
+
+  folly::AsyncSocketException exception_;
+  bool gotTimestamp_{false};
+  bool gotByteSeq_{false};
+};
+
 class TestServer {
  public:
   // Create a TestServer.
index 5dada1165914bc937660ad659d7e7de86810a20f..bfde20c79ba9270e9156b7e9b793ae6ed7e7c48a 100644 (file)
@@ -2817,4 +2817,95 @@ TEST(AsyncSocketTest, EvbCallbacks) {
   socket->attachEventBase(&evb);
 }
 
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+  // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0),
+  // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+  // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2),
+  // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3),
+  SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+  // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5),
+  // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6),
+  SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+  SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+  // SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+  SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
+  SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+
+  // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY,
+  // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST,
+};
+TEST(AsyncSocketTest, ErrMessageCallback) {
+  TestServer server;
+
+  // connect()
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+  LOG(INFO) << "Client socket fd=" << socket->getFd();
+
+  // Let the socket
+  evb.loop();
+
+  ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
+
+  // Set read callback to keep the socket subscribed for event
+  // notifications. Though we're no planning to read anything from
+  // this side of the connection.
+  ReadCallback rcb(1);
+  socket->setReadCB(&rcb);
+
+  // Set up timestamp callbacks
+  TestErrMessageCallback errMsgCB;
+  socket->setErrMessageCB(&errMsgCB);
+  ASSERT_EQ(socket->getErrMessageCallback(),
+            static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
+
+  // Enable timestamp notifications
+  ASSERT_GT(socket->getFd(), 0);
+  int flags = SOF_TIMESTAMPING_OPT_ID
+              | SOF_TIMESTAMPING_OPT_TSONLY
+              | SOF_TIMESTAMPING_SOFTWARE
+              | SOF_TIMESTAMPING_OPT_CMSG
+              | SOF_TIMESTAMPING_TX_SCHED;
+  AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+  EXPECT_EQ(tstampingOpt.apply(socket->getFd(), flags), 0);
+
+  // write()
+  std::vector<uint8_t> wbuf(128, 'a');
+  WriteCallback wcb;
+  socket->write(&wcb, wbuf.data(), wbuf.size());
+
+  // Accept the connection.
+  std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
+  LOG(INFO) << "Server socket fd=" << acceptedSocket->getSocketFD();
+
+  // Loop
+  evb.loopOnce();
+  ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
+
+  // Check that we can read the data that was written to the socket
+  std::vector<uint8_t> rbuf(1 + wbuf.size(), 0);
+  uint32_t bytesRead = acceptedSocket->read(rbuf.data(), rbuf.size());
+  ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));
+  ASSERT_EQ(bytesRead, wbuf.size());
+
+  // Close both sockets
+  acceptedSocket->close();
+  socket->close();
+
+  ASSERT_TRUE(socket->isClosedBySelf());
+  ASSERT_FALSE(socket->isClosedByPeer());
+
+  // Check for the timestamp notifications.
+  ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN);
+  ASSERT_TRUE(errMsgCB.gotByteSeq_);
+  ASSERT_TRUE(errMsgCB.gotTimestamp_);
+}
+#endif // MSG_ERRQUEUE
+
 #endif