Implementing a callback interface for folly::AsyncSocket allowing to supply an ancill...
[folly.git] / folly / io / async / test / AsyncSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/scoped_array.hpp>
23
24 enum StateEnum {
25   STATE_WAITING,
26   STATE_SUCCEEDED,
27   STATE_FAILED
28 };
29
30 typedef std::function<void()> VoidCallback;
31
32 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
33  public:
34   ConnCallback()
35       : state(STATE_WAITING),
36         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
37
38   void connectSuccess() noexcept override {
39     state = STATE_SUCCEEDED;
40     if (successCallback) {
41       successCallback();
42     }
43   }
44
45   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
46     state = STATE_FAILED;
47     exception = ex;
48     if (errorCallback) {
49       errorCallback();
50     }
51   }
52
53   StateEnum state;
54   folly::AsyncSocketException exception;
55   VoidCallback successCallback;
56   VoidCallback errorCallback;
57 };
58
59 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
60  public:
61   WriteCallback()
62       : state(STATE_WAITING),
63         bytesWritten(0),
64         exception(folly::AsyncSocketException::UNKNOWN, "none") {}
65
66   void writeSuccess() noexcept override {
67     state = STATE_SUCCEEDED;
68     if (successCallback) {
69       successCallback();
70     }
71   }
72
73   void writeErr(size_t nBytesWritten,
74                 const folly::AsyncSocketException& ex) noexcept override {
75     LOG(ERROR) << ex.what();
76     state = STATE_FAILED;
77     this->bytesWritten = nBytesWritten;
78     exception = ex;
79     if (errorCallback) {
80       errorCallback();
81     }
82   }
83
84   StateEnum state;
85   size_t bytesWritten;
86   folly::AsyncSocketException exception;
87   VoidCallback successCallback;
88   VoidCallback errorCallback;
89 };
90
91 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
92  public:
93   explicit ReadCallback(size_t _maxBufferSz = 4096)
94       : state(STATE_WAITING),
95         exception(folly::AsyncSocketException::UNKNOWN, "none"),
96         buffers(),
97         maxBufferSz(_maxBufferSz) {}
98
99   ~ReadCallback() {
100     for (std::vector<Buffer>::iterator it = buffers.begin();
101          it != buffers.end();
102          ++it) {
103       it->free();
104     }
105     currentBuffer.free();
106   }
107
108   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
109     if (!currentBuffer.buffer) {
110       currentBuffer.allocate(maxBufferSz);
111     }
112     *bufReturn = currentBuffer.buffer;
113     *lenReturn = currentBuffer.length;
114   }
115
116   void readDataAvailable(size_t len) noexcept override {
117     currentBuffer.length = len;
118     buffers.push_back(currentBuffer);
119     currentBuffer.reset();
120     if (dataAvailableCallback) {
121       dataAvailableCallback();
122     }
123   }
124
125   void readEOF() noexcept override {
126     state = STATE_SUCCEEDED;
127   }
128
129   void readErr(const folly::AsyncSocketException& ex) noexcept override {
130     state = STATE_FAILED;
131     exception = ex;
132   }
133
134   void verifyData(const char* expected, size_t expectedLen) const {
135     size_t offset = 0;
136     for (size_t idx = 0; idx < buffers.size(); ++idx) {
137       const auto& buf = buffers[idx];
138       size_t cmpLen = std::min(buf.length, expectedLen - offset);
139       CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
140       CHECK_EQ(cmpLen, buf.length);
141       offset += cmpLen;
142     }
143     CHECK_EQ(offset, expectedLen);
144   }
145
146   size_t dataRead() const {
147     size_t ret = 0;
148     for (const auto& buf : buffers) {
149       ret += buf.length;
150     }
151     return ret;
152   }
153
154   class Buffer {
155    public:
156     Buffer() : buffer(nullptr), length(0) {}
157     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
158
159     void reset() {
160       buffer = nullptr;
161       length = 0;
162     }
163     void allocate(size_t len) {
164       assert(buffer == nullptr);
165       this->buffer = static_cast<char*>(malloc(len));
166       this->length = len;
167     }
168     void free() {
169       ::free(buffer);
170       reset();
171     }
172
173     char* buffer;
174     size_t length;
175   };
176
177   StateEnum state;
178   folly::AsyncSocketException exception;
179   std::vector<Buffer> buffers;
180   Buffer currentBuffer;
181   VoidCallback dataAvailableCallback;
182   const size_t maxBufferSz;
183 };
184
185 class BufferCallback : public folly::AsyncTransport::BufferCallback {
186  public:
187   BufferCallback() : buffered_(false), bufferCleared_(false) {}
188
189   void onEgressBuffered() override { buffered_ = true; }
190
191   void onEgressBufferCleared() override { bufferCleared_ = true; }
192
193   bool hasBuffered() const { return buffered_; }
194
195   bool hasBufferCleared() const { return bufferCleared_; }
196
197  private:
198   bool buffered_{false};
199   bool bufferCleared_{false};
200 };
201
202 class ReadVerifier {
203 };
204
205 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
206  public:
207   TestErrMessageCallback()
208     : exception_(folly::AsyncSocketException::UNKNOWN, "none")
209   {}
210
211   void errMessage(const cmsghdr& cmsg) noexcept override {
212     if (cmsg.cmsg_level == SOL_SOCKET &&
213       cmsg.cmsg_type == SCM_TIMESTAMPING) {
214       gotTimestamp_ = true;
215     } else if (
216       (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
217       (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
218       gotByteSeq_ = true;
219     }
220   }
221
222   void errMessageError(
223       const folly::AsyncSocketException& ex) noexcept override {
224     exception_ = ex;
225   }
226
227   folly::AsyncSocketException exception_;
228   bool gotTimestamp_{false};
229   bool gotByteSeq_{false};
230 };
231
232 class TestSendMsgParamsCallback :
233     public folly::AsyncSocket::SendMsgParamsCallback {
234  public:
235   TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
236   : flags_(flags),
237     writeFlags_(folly::WriteFlags::NONE),
238     dataSize_(dataSize),
239     data_(data),
240     queriedFlags_(false),
241     queriedData_(false)
242   {}
243
244   void reset(int flags) {
245     flags_ = flags;
246     writeFlags_ = folly::WriteFlags::NONE;
247     queriedFlags_ = false;
248     queriedData_ = false;
249   }
250
251   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
252                                                                   override {
253     queriedFlags_ = true;
254     if (writeFlags_ == folly::WriteFlags::NONE) {
255       writeFlags_ = flags;
256     } else {
257       assert(flags == writeFlags_);
258     }
259     return flags_;
260   }
261
262   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
263     queriedData_ = true;
264     if (writeFlags_ == folly::WriteFlags::NONE) {
265       writeFlags_ = flags;
266     } else {
267       assert(flags == writeFlags_);
268     }
269     assert(data != nullptr);
270     memcpy(data, data_, dataSize_);
271   }
272
273   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
274     if (writeFlags_ == folly::WriteFlags::NONE) {
275       writeFlags_ = flags;
276     } else {
277       assert(flags == writeFlags_);
278     }
279     return dataSize_;
280   }
281
282   int flags_;
283   folly::WriteFlags writeFlags_;
284   uint32_t dataSize_;
285   void* data_;
286   bool queriedFlags_;
287   bool queriedData_;
288 };
289
290 class TestServer {
291  public:
292   // Create a TestServer.
293   // This immediately starts listening on an ephemeral port.
294   explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
295     namespace fsp = folly::portability::sockets;
296     fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
297     if (fd_ < 0) {
298       throw folly::AsyncSocketException(
299           folly::AsyncSocketException::INTERNAL_ERROR,
300           "failed to create test server socket",
301           errno);
302     }
303     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
304       throw folly::AsyncSocketException(
305           folly::AsyncSocketException::INTERNAL_ERROR,
306           "failed to put test server socket in "
307           "non-blocking mode",
308           errno);
309     }
310     if (enableTFO) {
311 #if FOLLY_ALLOW_TFO
312       folly::detail::tfo_enable(fd_, 100);
313 #endif
314     }
315
316     struct addrinfo hints, *res;
317     memset(&hints, 0, sizeof(hints));
318     hints.ai_family = AF_INET;
319     hints.ai_socktype = SOCK_STREAM;
320     hints.ai_flags = AI_PASSIVE;
321
322     if (getaddrinfo(nullptr, "0", &hints, &res)) {
323       throw folly::AsyncSocketException(
324           folly::AsyncSocketException::INTERNAL_ERROR,
325           "Attempted to bind address to socket with "
326           "bad getaddrinfo",
327           errno);
328     }
329
330     SCOPE_EXIT {
331       freeaddrinfo(res);
332     };
333
334     if (bufSize > 0) {
335       setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
336       setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
337     }
338
339     if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
340       throw folly::AsyncSocketException(
341           folly::AsyncSocketException::INTERNAL_ERROR,
342           "failed to bind to async server socket for port 10",
343           errno);
344     }
345
346     if (listen(fd_, 10) != 0) {
347       throw folly::AsyncSocketException(
348           folly::AsyncSocketException::INTERNAL_ERROR,
349           "failed to listen on test server socket",
350           errno);
351     }
352
353     address_.setFromLocalAddress(fd_);
354     // The local address will contain 0.0.0.0.
355     // Change it to 127.0.0.1, so it can be used to connect to the server
356     address_.setFromIpPort("127.0.0.1", address_.getPort());
357   }
358
359   ~TestServer() {
360     if (fd_ != -1) {
361       close(fd_);
362     }
363   }
364
365   // Get the address for connecting to the server
366   const folly::SocketAddress& getAddress() const {
367     return address_;
368   }
369
370   int acceptFD(int timeout=50) {
371     namespace fsp = folly::portability::sockets;
372     struct pollfd pfd;
373     pfd.fd = fd_;
374     pfd.events = POLLIN;
375     int ret = poll(&pfd, 1, timeout);
376     if (ret == 0) {
377       throw folly::AsyncSocketException(
378           folly::AsyncSocketException::INTERNAL_ERROR,
379           "test server accept() timed out");
380     } else if (ret < 0) {
381       throw folly::AsyncSocketException(
382           folly::AsyncSocketException::INTERNAL_ERROR,
383           "test server accept() poll failed",
384           errno);
385     }
386
387     int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
388     if (acceptedFd < 0) {
389       throw folly::AsyncSocketException(
390           folly::AsyncSocketException::INTERNAL_ERROR,
391           "test server accept() failed",
392           errno);
393     }
394
395     return acceptedFd;
396   }
397
398   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
399     int fd = acceptFD(timeout);
400     return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
401   }
402
403   std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
404                                                   int timeout = 50) {
405     int fd = acceptFD(timeout);
406     return folly::AsyncSocket::newSocket(evb, fd);
407   }
408
409   /**
410    * Accept a connection, read data from it, and verify that it matches the
411    * data in the specified buffer.
412    */
413   void verifyConnection(const char* buf, size_t len) {
414     // accept a connection
415     std::shared_ptr<BlockingSocket> acceptedSocket = accept();
416     // read the data and compare it to the specified buffer
417     boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
418     acceptedSocket->readAll(readbuf.get(), len);
419     CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
420     // make sure we get EOF next
421     uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
422     CHECK_EQ(bytesRead, 0);
423   }
424
425  private:
426   int fd_;
427   folly::SocketAddress address_;
428 };