Fix copyright lines
[folly.git] / folly / io / async / AsyncTransport.h
index a1023b2cca4c1c407e5877557e920a1d8b37a0ad..050c22bf9ae719e6a363d0c79f6e35d27b4dbcb4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2014-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #include <folly/io/async/AsyncSocketBase.h>
 #include <folly/io/async/DelayedDestruction.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/io/async/ssl/OpenSSLPtrTypes.h>
+#include <folly/portability/OpenSSL.h>
 #include <folly/portability/SysUio.h>
-
-#include <openssl/ssl.h>
+#include <folly/ssl/OpenSSLPtrTypes.h>
 
 constexpr bool kOpenSslModeMoveBufferOwnership =
 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
@@ -61,6 +60,10 @@ enum class WriteFlags : uint32_t {
    * this indicates that only the write side of socket should be shutdown
    */
   WRITE_SHUTDOWN = 0x04,
+  /*
+   * use msg zerocopy if allowed
+   */
+  WRITE_MSG_ZEROCOPY = 0x08,
 };
 
 /*
@@ -71,6 +74,14 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) {
     static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
 }
 
+/*
+ * compound assignment union operator
+ */
+inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
+  a = a | b;
+  return a;
+}
+
 /*
  * intersection operator
  */
@@ -79,6 +90,14 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) {
     static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
 }
 
+/*
+ * compound assignment intersection operator
+ */
+inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
+  a = a & b;
+  return a;
+}
+
 /*
  * exclusion parameter
  */
@@ -223,6 +242,16 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
    */
   virtual bool readable() const = 0;
 
+  /**
+   * Determine if the transport is writable or not.
+   *
+   * @return  true iff the transport is writable, false otherwise.
+   */
+  virtual bool writable() const {
+    // By default return good() - leave it to implementers to override.
+    return good();
+  }
+
   /**
    * Determine if the there is pending data on the transport.
    *
@@ -306,7 +335,21 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
    */
   virtual void getLocalAddress(SocketAddress* address) const = 0;
 
-  virtual void getAddress(SocketAddress* address) const {
+  /**
+   * Get the address of the remote endpoint to which this transport is
+   * connected.
+   *
+   * This function may throw AsyncSocketException on error.
+   *
+   * @return         Return the local address
+   */
+  SocketAddress getLocalAddress() const {
+    SocketAddress addr;
+    getLocalAddress(&addr);
+    return addr;
+  }
+
+  void getAddress(SocketAddress* address) const override {
     getLocalAddress(address);
   }
 
@@ -321,6 +364,20 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
    */
   virtual void getPeerAddress(SocketAddress* address) const = 0;
 
+  /**
+   * Get the address of the remote endpoint to which this transport is
+   * connected.
+   *
+   * This function may throw AsyncSocketException on error.
+   *
+   * @return         Return the remote endpoint's address
+   */
+  SocketAddress getPeerAddress() const {
+    SocketAddress addr;
+    getPeerAddress(&addr);
+    return addr;
+  }
+
   /**
    * Get the certificate used to authenticate the peer.
    */
@@ -333,6 +390,22 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
     return nullptr;
   }
 
+  /**
+   * Return the application protocol being used by the underlying transport
+   * protocol. This is useful for transports which are used to tunnel other
+   * protocols.
+   */
+  virtual std::string getApplicationProtocol() noexcept {
+    return "";
+  }
+
+  /**
+   * Returns the name of the security protocol being used.
+   */
+  virtual std::string getSecurityProtocol() const {
+    return "";
+  }
+
   /**
    * @return True iff end of record tracking is enabled
    */
@@ -385,7 +458,7 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
   }
 
  protected:
-  virtual ~AsyncTransport() = default;
+  ~AsyncTransport() override = default;
 };
 
 class AsyncReader {
@@ -493,7 +566,7 @@ class AsyncReader {
      */
 
     virtual void readBufferAvailable(std::unique_ptr<IOBuf> /*readBuf*/)
-      noexcept {};
+      noexcept {}
 
     /**
      * readEOF() will be invoked when the transport is closed.
@@ -577,15 +650,22 @@ class AsyncTransportWrapper : virtual public AsyncTransport,
   // to keep compatibility.
   using ReadCallback    = AsyncReader::ReadCallback;
   using WriteCallback   = AsyncWriter::WriteCallback;
-  virtual void setReadCB(ReadCallback* callback) override = 0;
-  virtual ReadCallback* getReadCallback() const override = 0;
-  virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
-                     WriteFlags flags = WriteFlags::NONE) override = 0;
-  virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
-                      WriteFlags flags = WriteFlags::NONE) override = 0;
-  virtual void writeChain(WriteCallback* callback,
-                          std::unique_ptr<IOBuf>&& buf,
-                          WriteFlags flags = WriteFlags::NONE) override = 0;
+  void setReadCB(ReadCallback* callback) override = 0;
+  ReadCallback* getReadCallback() const override = 0;
+  void write(
+      WriteCallback* callback,
+      const void* buf,
+      size_t bytes,
+      WriteFlags flags = WriteFlags::NONE) override = 0;
+  void writev(
+      WriteCallback* callback,
+      const iovec* vec,
+      size_t count,
+      WriteFlags flags = WriteFlags::NONE) override = 0;
+  void writeChain(
+      WriteCallback* callback,
+      std::unique_ptr<IOBuf>&& buf,
+      WriteFlags flags = WriteFlags::NONE) override = 0;
   /**
    * The transport wrapper may wrap another transport. This returns the
    * transport that is wrapped. It returns nullptr if there is no wrapped
@@ -618,20 +698,6 @@ class AsyncTransportWrapper : virtual public AsyncTransport,
     return const_cast<T*>(static_cast<const AsyncTransportWrapper*>(this)
         ->getUnderlyingTransport<T>());
   }
-
-  /**
-   * Return the application protocol being used by the underlying transport
-   * protocol. This is useful for transports which are used to tunnel other
-   * protocols.
-   */
-  virtual std::string getApplicationProtocol() noexcept {
-    return "";
-  }
-
-  /**
-   * Returns the name of the security protocol being used.
-   */
-  virtual std::string getSecurityProtocol() const { return ""; }
 };
 
-} // folly
+} // namespace folly