Fix copyright lines
[folly.git] / folly / io / async / test / AsyncSocketTest.h
index 3bae6a087d0870642c546f6523ccc780a7634b2b..fe69a4e4b3a35c57c9ab6a5f0303d39a32d55e59 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2015-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 #include <folly/portability/Sockets.h>
 
 #include <boost/scoped_array.hpp>
+#include <memory>
 
 enum StateEnum {
   STATE_WAITING,
@@ -70,11 +71,11 @@ class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
     }
   }
 
-  void writeErr(size_t bytesWritten,
+  void writeErr(size_t nBytesWritten,
                 const folly::AsyncSocketException& ex) noexcept override {
     LOG(ERROR) << ex.what();
     state = STATE_FAILED;
-    this->bytesWritten = bytesWritten;
+    this->bytesWritten = nBytesWritten;
     exception = ex;
     if (errorCallback) {
       errorCallback();
@@ -82,7 +83,7 @@ class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
   }
 
   StateEnum state;
-  size_t bytesWritten;
+  std::atomic<size_t> bytesWritten;
   folly::AsyncSocketException exception;
   VoidCallback successCallback;
   VoidCallback errorCallback;
@@ -96,7 +97,7 @@ class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
         buffers(),
         maxBufferSz(_maxBufferSz) {}
 
-  ~ReadCallback() {
+  ~ReadCallback() override {
     for (std::vector<Buffer>::iterator it = buffers.begin();
          it != buffers.end();
          ++it) {
@@ -160,10 +161,10 @@ class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
       buffer = nullptr;
       length = 0;
     }
-    void allocate(size_t length) {
+    void allocate(size_t len) {
       assert(buffer == nullptr);
-      this->buffer = static_cast<char*>(malloc(length));
-      this->length = length;
+      this->buffer = static_cast<char*>(malloc(len));
+      this->length = len;
     }
     void free() {
       ::free(buffer);
@@ -202,6 +203,64 @@ class BufferCallback : public folly::AsyncTransport::BufferCallback {
 class ReadVerifier {
 };
 
+class TestSendMsgParamsCallback :
+    public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+  TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
+  : flags_(flags),
+    writeFlags_(folly::WriteFlags::NONE),
+    dataSize_(dataSize),
+    data_(data),
+    queriedFlags_(false),
+    queriedData_(false)
+  {}
+
+  void reset(int flags) {
+    flags_ = flags;
+    writeFlags_ = folly::WriteFlags::NONE;
+    queriedFlags_ = false;
+    queriedData_ = false;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    queriedFlags_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return flags_;
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    queriedData_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    assert(data != nullptr);
+    memcpy(data, data_, dataSize_);
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return dataSize_;
+  }
+
+  int flags_;
+  folly::WriteFlags writeFlags_;
+  uint32_t dataSize_;
+  void* data_;
+  bool queriedFlags_;
+  bool queriedData_;
+};
+
 class TestServer {
  public:
   // Create a TestServer.
@@ -312,7 +371,7 @@ class TestServer {
 
   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
     int fd = acceptFD(timeout);
-    return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+    return std::make_shared<BlockingSocket>(fd);
   }
 
   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,