Add SO_ZEROCOPY support
[folly.git] / folly / io / async / AsyncSocket.h
index 973e493cb81f3adcaf1334654cb8d1e645088754..e99300fb238a6ae491f453bdbc886b405037adda 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,8 +31,8 @@
 #include <sys/types.h>
 
 #include <chrono>
-#include <memory>
 #include <map>
+#include <memory>
 
 namespace folly {
 
@@ -64,6 +64,14 @@ namespace folly {
  * responding and no further progress can be made sending the data.
  */
 
+#if defined __linux__ && !defined SO_NO_TRANSPARENT_TLS
+#define SO_NO_TRANSPARENT_TLS 200
+#endif
+
+#if defined __linux__ && !defined SO_NO_TSOCKS
+#define SO_NO_TSOCKS 201
+#endif
+
 #ifdef _MSC_VER
 // We do a dynamic_cast on this, in
 // AsyncTransportWrapper::getUnderlyingTransport so be safe and
@@ -94,6 +102,118 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       noexcept = 0;
   };
 
+  class EvbChangeCallback {
+   public:
+    virtual ~EvbChangeCallback() = default;
+
+    // Called when the socket has been attached to a new EVB
+    // and is called from within that EVB thread
+    virtual void evbAttached(AsyncSocket* socket) = 0;
+
+    // Called when the socket is detached from an EVB and
+    // is called from the EVB thread being detached
+    virtual void evbDetached(AsyncSocket* socket) = 0;
+  };
+
+  /**
+   * This interface is implemented only for platforms supporting
+   * per-socket error queues.
+   */
+  class ErrMessageCallback {
+   public:
+    virtual ~ErrMessageCallback() = default;
+
+    /**
+     * errMessage() will be invoked when kernel puts a message to
+     * the error queue associated with the socket.
+     *
+     * @param cmsg      Reference to cmsghdr structure describing
+     *                  a message read from error queue associated
+     *                  with the socket.
+     */
+    virtual void
+    errMessage(const cmsghdr& cmsg) noexcept = 0;
+
+    /**
+     * errMessageError() will be invoked if an error occurs reading a message
+     * from the socket error stream.
+     *
+     * @param ex        An exception describing the error that occurred.
+     */
+    virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
+  };
+
+  class SendMsgParamsCallback {
+   public:
+    virtual ~SendMsgParamsCallback() = default;
+
+    /**
+     * getFlags() will be invoked to retrieve the desired flags to be passed
+     * to ::sendmsg() system call. This method was intentionally declared
+     * non-virtual, so there is no way to override it. Instead feel free to
+     * override getFlagsImpl(flags, defaultFlags) method instead, and enjoy
+     * the convenience of defaultFlags passed there.
+     *
+     * @param flags     Write flags requested for the given write operation
+     */
+    int getFlags(folly::WriteFlags flags, bool zeroCopyEnabled) noexcept {
+      return getFlagsImpl(flags, getDefaultFlags(flags, zeroCopyEnabled));
+    }
+
+    /**
+     * getAncillaryData() will be invoked to initialize ancillary data
+     * buffer referred by "msg_control" field of msghdr structure passed to
+     * ::sendmsg() system call. The function assumes that the size of buffer
+     * is not smaller than the value returned by getAncillaryDataSize() method
+     * for the same combination of flags.
+     *
+     * @param flags     Write flags requested for the given write operation
+     * @param data      Pointer to ancillary data buffer to initialize.
+     */
+    virtual void getAncillaryData(
+      folly::WriteFlags /*flags*/,
+      void* /*data*/) noexcept {}
+
+    /**
+     * getAncillaryDataSize() will be invoked to retrieve the size of
+     * ancillary data buffer which should be passed to ::sendmsg() system call
+     *
+     * @param flags     Write flags requested for the given write operation
+     */
+    virtual uint32_t getAncillaryDataSize(folly::WriteFlags /*flags*/)
+        noexcept {
+      return 0;
+    }
+
+    static const size_t maxAncillaryDataSize{0x5000};
+
+   private:
+    /**
+     * getFlagsImpl() will be invoked by getFlags(folly::WriteFlags flags)
+     * method to retrieve the flags to be passed to ::sendmsg() system call.
+     * SendMsgParamsCallback::getFlags() is calling this method, and returns
+     * its results directly to the caller in AsyncSocket.
+     * Classes inheriting from SendMsgParamsCallback are welcome to override
+     * this method to force SendMsgParamsCallback to return its own set
+     * of flags.
+     *
+     * @param flags        Write flags requested for the given write operation
+     * @param defaultflags A set of message flags returned by getDefaultFlags()
+     *                     method for the given "flags" mask.
+     */
+    virtual int getFlagsImpl(folly::WriteFlags /*flags*/, int defaultFlags) {
+      return defaultFlags;
+    }
+
+    /**
+     * getDefaultFlags() will be invoked by  getFlags(folly::WriteFlags flags)
+     * to retrieve the default set of flags, and pass them to getFlagsImpl(...)
+     *
+     * @param flags     Write flags requested for the given write operation
+     */
+    int getDefaultFlags(folly::WriteFlags flags, bool zeroCopyEnabled) noexcept;
+  };
+
   explicit AsyncSocket();
   /**
    * Create a new unconnected AsyncSocket.
@@ -144,6 +264,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    */
   AsyncSocket(EventBase* evb, int fd);
 
+  /**
+   * Create an AsyncSocket from a different, already connected AsyncSocket.
+   *
+   * Similar to AsyncSocket(evb, fd) when fd was previously owned by an
+   * AsyncSocket.
+   */
+  explicit AsyncSocket(AsyncSocket::UniquePtr);
+
   /**
    * Helper function to create a shared_ptr<AsyncSocket>.
    *
@@ -196,7 +324,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * This prevents callers from deleting a AsyncSocket while it is invoking a
    * callback.
    */
-  virtual void destroy() override;
+  void destroy() override;
 
   /**
    * Get the EventBase used by this socket.
@@ -219,6 +347,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * error.  The AsyncSocket may no longer be used after the file descriptor
    * has been extracted.
    *
+   * This method should be used with care as the resulting fd is not guaranteed
+   * to perfectly reflect the state of the AsyncSocket (security state,
+   * pre-received data, etc.).
+   *
    * Returns the file descriptor.  The caller assumes ownership of the
    * descriptor, and it will not be closed when the AsyncSocket is destroyed.
    */
@@ -336,10 +468,57 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return maxReadsPerEvent_;
   }
 
+  /**
+   * Set a pointer to ErrMessageCallback implementation which will be
+   * receiving notifications for messages posted to the error queue
+   * associated with the socket.
+   * ErrMessageCallback is implemented only for platforms with
+   * per-socket error message queus support (recvmsg() system call must
+   * )
+   *
+   */
+  virtual void setErrMessageCB(ErrMessageCallback* callback);
+
+  /**
+   * Get a pointer to ErrMessageCallback implementation currently
+   * registered with this socket.
+   *
+   */
+  virtual ErrMessageCallback* getErrMessageCallback() const;
+
+  /**
+   * Set a pointer to SendMsgParamsCallback implementation which
+   * will be used to form ::sendmsg() system call parameters
+   *
+   */
+  virtual void setSendMsgParamCB(SendMsgParamsCallback* callback);
+
+  /**
+   * Get a pointer to SendMsgParamsCallback implementation currently
+   * registered with this socket.
+   *
+   */
+  virtual SendMsgParamsCallback* getSendMsgParamsCB() const;
+
   // Read and write methods
   void setReadCB(ReadCallback* callback) override;
   ReadCallback* getReadCallback() const override;
 
+  static const size_t kDefaultZeroCopyThreshold = 32768; // 32KB
+
+  bool setZeroCopy(bool enable);
+  bool getZeroCopy() const {
+    return zeroCopyEnabled_;
+  }
+
+  void setZeroCopyWriteChainThreshold(size_t threshold);
+  size_t getZeroCopyWriteChainThreshold() const {
+    return zeroCopyWriteChainThreshold_;
+  }
+
+  bool isZeroCopyMsg(const cmsghdr& cmsg) const;
+  void processZeroCopyMsg(const cmsghdr& cmsg);
+
   void write(WriteCallback* callback, const void* buf, size_t bytes,
              WriteFlags flags = WriteFlags::NONE) override;
   void writev(WriteCallback* callback, const iovec* vec, size_t count,
@@ -362,6 +541,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void shutdownWriteNow() override;
 
   bool readable() const override;
+  bool writable() const override;
   bool isPending() const override;
   virtual bool hangup() const;
   bool good() const override;
@@ -375,9 +555,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void getPeerAddress(
     folly::SocketAddress* address) const override;
 
-  bool isEorTrackingEnabled() const override { return false; }
+  bool isEorTrackingEnabled() const override {
+    return trackEor_;
+  }
 
-  void setEorTracking(bool /*track*/) override {}
+  void setEorTracking(bool track) override {
+    trackEor_ = track;
+  }
 
   bool connecting() const override {
     return (state_ == StateEnum::CONNECTING);
@@ -534,8 +718,51 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return setsockopt(fd_, level, optname, optval, sizeof(T));
   }
 
-  virtual void setPeek(bool peek) {
-    peek_ = peek;
+  /**
+   * Virtual method for reading a socket option returning integer
+   * value, which is the most typical case. Convenient for overriding
+   * and mocking.
+   *
+   * @param level     same as the "level" parameter in getsockopt().
+   * @param optname   same as the "optname" parameter in getsockopt().
+   * @param optval    same as "optval" parameter in getsockopt().
+   * @param optlen    same as "optlen" parameter in getsockopt().
+   * @return          same as the return value of getsockopt().
+   */
+  virtual int
+  getSockOptVirtual(int level, int optname, void* optval, socklen_t* optlen) {
+    return getsockopt(fd_, level, optname, optval, optlen);
+  }
+
+  /**
+   * Virtual method for setting a socket option accepting integer
+   * value, which is the most typical case. Convenient for overriding
+   * and mocking.
+   *
+   * @param level     same as the "level" parameter in setsockopt().
+   * @param optname   same as the "optname" parameter in setsockopt().
+   * @param optval    same as "optval" parameter in setsockopt().
+   * @param optlen    same as "optlen" parameter in setsockopt().
+   * @return          same as the return value of setsockopt().
+   */
+  virtual int setSockOptVirtual(
+      int level,
+      int optname,
+      void const* optval,
+      socklen_t optlen) {
+    return setsockopt(fd_, level, optname, optval, optlen);
+  }
+
+  /**
+   * Set pre-received data, to be returned to read callback before any data
+   * from the socket.
+   */
+  virtual void setPreReceivedData(std::unique_ptr<IOBuf> data) {
+    if (preReceivedData_) {
+      preReceivedData_->prependChain(std::move(data));
+    } else {
+      preReceivedData_ = std::move(data);
+    }
   }
 
   /**
@@ -549,6 +776,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 #endif
   }
 
+  void disableTransparentTls() {
+    noTransparentTls_ = true;
+  }
+
+  void disableTSocks() {
+    noTSocks_ = true;
+  }
+
   enum class StateEnum : uint8_t {
     UNINIT,
     CONNECTING,
@@ -560,6 +795,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   void setBufferCallback(BufferCallback* cb);
 
+  // Callers should set this prior to connecting the socket for the safest
+  // behavior.
+  void setEvbChangedCallback(std::unique_ptr<EvbChangeCallback> cb) {
+    evbChangeCb_ = std::move(cb);
+  }
+
+  /**
+   * Attempt to cache the current local and peer addresses (if not already
+   * cached) so that they are available from getPeerAddress() and
+   * getLocalAddress() even after the socket is closed.
+   */
+  void cacheAddresses();
+
   /**
    * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
    * If no data has been written, 0 is returned.
@@ -668,7 +916,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * destroy() instead.  (See the documentation in DelayedDestruction.h for
    * more details.)
    */
-  ~AsyncSocket();
+  ~AsyncSocket() override;
 
   friend std::ostream& operator << (std::ostream& os, const StateEnum& state);
 
@@ -701,7 +949,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       : AsyncTimeout(eventBase)
       , socket_(socket) {}
 
-    virtual void timeoutExpired() noexcept {
+    void timeoutExpired() noexcept override {
       socket_->timeoutExpired();
     }
 
@@ -718,7 +966,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       : EventHandler(eventBase, fd)
       , socket_(socket) {}
 
-    virtual void handlerReady(uint16_t events) noexcept {
+    void handlerReady(uint16_t events) noexcept override {
       socket_->ioReady(events);
     }
 
@@ -768,6 +1016,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   virtual void checkForImmediateRead() noexcept;
   virtual void handleInitialReadWrite() noexcept;
   virtual void prepareReadBuffer(void** buf, size_t* buflen);
+  virtual void handleErrMessages() noexcept;
   virtual void handleRead() noexcept;
   virtual void handleWrite() noexcept;
   virtual void handleConnect() noexcept;
@@ -882,6 +1131,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void fail(const char* fn, const AsyncSocketException& ex);
   void failConnect(const char* fn, const AsyncSocketException& ex);
   void failRead(const char* fn, const AsyncSocketException& ex);
+  void failErrMessageRead(const char* fn, const AsyncSocketException& ex);
   void failWrite(const char* fn, WriteCallback* callback, size_t bytesWritten,
                  const AsyncSocketException& ex);
   void failWrite(const char* fn, const AsyncSocketException& ex);
@@ -889,37 +1139,72 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   virtual void invokeConnectErr(const AsyncSocketException& ex);
   virtual void invokeConnectSuccess();
   void invalidState(ConnectCallback* callback);
+  void invalidState(ErrMessageCallback* callback);
   void invalidState(ReadCallback* callback);
   void invalidState(WriteCallback* callback);
 
   std::string withAddr(const std::string& s);
 
-  StateEnum state_;                     ///< StateEnum describing current state
-  uint8_t shutdownFlags_;               ///< Shutdown state (ShutdownFlags)
-  uint16_t eventFlags_;                 ///< EventBase::HandlerFlags settings
-  int fd_;                              ///< The socket file descriptor
+  void cacheLocalAddress() const;
+  void cachePeerAddress() const;
+
+  bool isZeroCopyRequest(WriteFlags flags);
+  uint32_t getNextZeroCopyBuffId() {
+    return zeroCopyBuffId_++;
+  }
+  void adjustZeroCopyFlags(folly::IOBuf* buf, folly::WriteFlags& flags);
+  void adjustZeroCopyFlags(
+      const iovec* vec,
+      uint32_t count,
+      folly::WriteFlags& flags);
+  void addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf);
+  void addZeroCopyBuff(folly::IOBuf* ptr);
+  void setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf);
+  bool containsZeroCopyBuff(folly::IOBuf* ptr);
+  void releaseZeroCopyBuff(uint32_t id);
+
+  // a folly::IOBuf can be used in multiple partial requests
+  // so we keep a map that maps a buffer id to a raw folly::IOBuf ptr
+  // and one more map that adds a ref count for a folly::IOBuf that is either
+  // the original ptr or nullptr
+  uint32_t zeroCopyBuffId_{0};
+  std::unordered_map<uint32_t, folly::IOBuf*> idZeroCopyBufPtrMap_;
+  std::unordered_map<
+      folly::IOBuf*,
+      std::pair<uint32_t, std::unique_ptr<folly::IOBuf>>>
+      idZeroCopyBufPtrToBufMap_;
+
+  StateEnum state_;                      ///< StateEnum describing current state
+  uint8_t shutdownFlags_;                ///< Shutdown state (ShutdownFlags)
+  uint16_t eventFlags_;                  ///< EventBase::HandlerFlags settings
+  int fd_;                               ///< The socket file descriptor
   mutable folly::SocketAddress addr_;    ///< The address we tried to connect to
   mutable folly::SocketAddress localAddr_;
-                                        ///< The address we are connecting from
-  uint32_t sendTimeout_;                ///< The send timeout, in milliseconds
-  uint16_t maxReadsPerEvent_;           ///< Max reads per event loop iteration
-  EventBase* eventBase_;                ///< The EventBase
-  WriteTimeout writeTimeout_;           ///< A timeout for connect and write
-  IoHandler ioHandler_;                 ///< A EventHandler to monitor the fd
+                                         ///< The address we are connecting from
+  uint32_t sendTimeout_;                 ///< The send timeout, in milliseconds
+  uint16_t maxReadsPerEvent_;            ///< Max reads per event loop iteration
+  EventBase* eventBase_;                 ///< The EventBase
+  WriteTimeout writeTimeout_;            ///< A timeout for connect and write
+  IoHandler ioHandler_;                  ///< A EventHandler to monitor the fd
   ImmediateReadCB immediateReadHandler_; ///< LoopCallback for checking read
 
-  ConnectCallback* connectCallback_;    ///< ConnectCallback
-  ReadCallback* readCallback_;          ///< ReadCallback
-  WriteRequest* writeReqHead_;          ///< Chain of WriteRequests
-  WriteRequest* writeReqTail_;          ///< End of WriteRequest chain
+  ConnectCallback* connectCallback_;     ///< ConnectCallback
+  ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
+  SendMsgParamsCallback* ///< Callback for retrieving
+      sendMsgParamCallback_; ///< ::sendmsg() parameters
+  ReadCallback* readCallback_;           ///< ReadCallback
+  WriteRequest* writeReqHead_;           ///< Chain of WriteRequests
+  WriteRequest* writeReqTail_;           ///< End of WriteRequest chain
   ShutdownSocketSet* shutdownSocketSet_;
-  size_t appBytesReceived_;             ///< Num of bytes received from socket
-  size_t appBytesWritten_;              ///< Num of bytes written to socket
+  size_t appBytesReceived_;              ///< Num of bytes received from socket
+  size_t appBytesWritten_;               ///< Num of bytes written to socket
   bool isBufferMovable_{false};
 
-  bool peek_{false}; // Peek bytes.
+  // Pre-received data, to be returned to read callback before any data from the
+  // socket.
+  std::unique_ptr<IOBuf> preReceivedData_;
 
-  int8_t readErr_{READ_NO_ERROR};      ///< The read error encountered, if any.
+  int8_t readErr_{READ_NO_ERROR};        ///< The read error encountered, if any
 
   std::chrono::steady_clock::time_point connectStartTime_;
   std::chrono::steady_clock::time_point connectEndTime_;
@@ -930,9 +1215,18 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   bool tfoEnabled_{false};
   bool tfoAttempted_{false};
   bool tfoFinished_{false};
+  bool noTransparentTls_{false};
+  bool noTSocks_{false};
+  // Whether to track EOR or not.
+  bool trackEor_{false};
+  bool zeroCopyEnabled_{false};
+  bool zeroCopyVal_{false};
+  size_t zeroCopyWriteChainThreshold_{kDefaultZeroCopyThreshold};
+
+  std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
 };
 #ifdef _MSC_VER
 #pragma vtordisp(pop)
 #endif
 
-} // folly
+} // namespace folly