Fix copyright lines
[folly.git] / folly / io / async / test / AsyncSocketTest.h
index d69c851b7fe1f74ba080df5eb0958404f3192f30..fe69a4e4b3a35c57c9ab6a5f0303d39a32d55e59 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2015-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 #include <folly/portability/Sockets.h>
 
 #include <boost/scoped_array.hpp>
+#include <memory>
 
 enum StateEnum {
   STATE_WAITING,
@@ -82,7 +83,7 @@ class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
   }
 
   StateEnum state;
-  size_t bytesWritten;
+  std::atomic<size_t> bytesWritten;
   folly::AsyncSocketException exception;
   VoidCallback successCallback;
   VoidCallback errorCallback;
@@ -96,7 +97,7 @@ class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
         buffers(),
         maxBufferSz(_maxBufferSz) {}
 
-  ~ReadCallback() {
+  ~ReadCallback() override {
     for (std::vector<Buffer>::iterator it = buffers.begin();
          it != buffers.end();
          ++it) {
@@ -202,31 +203,62 @@ class BufferCallback : public folly::AsyncTransport::BufferCallback {
 class ReadVerifier {
 };
 
-class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
+class TestSendMsgParamsCallback :
+    public folly::AsyncSocket::SendMsgParamsCallback {
  public:
-  TestErrMessageCallback()
-    : exception_(folly::AsyncSocketException::UNKNOWN, "none")
+  TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
+  : flags_(flags),
+    writeFlags_(folly::WriteFlags::NONE),
+    dataSize_(dataSize),
+    data_(data),
+    queriedFlags_(false),
+    queriedData_(false)
   {}
 
-  void errMessage(const cmsghdr& cmsg) noexcept override {
-    if (cmsg.cmsg_level == SOL_SOCKET &&
-      cmsg.cmsg_type == SCM_TIMESTAMPING) {
-      gotTimestamp_ = true;
-    } else if (
-      (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
-      (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
-      gotByteSeq_ = true;
+  void reset(int flags) {
+    flags_ = flags;
+    writeFlags_ = folly::WriteFlags::NONE;
+    queriedFlags_ = false;
+    queriedData_ = false;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    queriedFlags_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return flags_;
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    queriedData_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
     }
+    assert(data != nullptr);
+    memcpy(data, data_, dataSize_);
   }
 
-  void errMessageError(
-      const folly::AsyncSocketException& ex) noexcept override {
-    exception_ = ex;
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return dataSize_;
   }
 
-  folly::AsyncSocketException exception_;
-  bool gotTimestamp_{false};
-  bool gotByteSeq_{false};
+  int flags_;
+  folly::WriteFlags writeFlags_;
+  uint32_t dataSize_;
+  void* data_;
+  bool queriedFlags_;
+  bool queriedData_;
 };
 
 class TestServer {
@@ -339,7 +371,7 @@ class TestServer {
 
   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
     int fd = acceptFD(timeout);
-    return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+    return std::make_shared<BlockingSocket>(fd);
   }
 
   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,