2 * Copyright 2015-present Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/io/async/AsyncSocket.h>
19 #include <folly/io/async/test/BlockingSocket.h>
20 #include <folly/portability/Sockets.h>
22 #include <boost/scoped_array.hpp>
// Hook type used by the test callbacks below: a callable taking no arguments
// and returning nothing. A default-constructed value is empty and tests false.
using VoidCallback = std::function<void()>;
// Connect callback for tests: tracks the outcome of a connect attempt in
// `state` and exposes optional hooks that tests may install.
33 class ConnCallback : public folly::AsyncSocket::ConnectCallback {
// Starts in STATE_WAITING; `exception` begins as an UNKNOWN/"none" placeholder.
36 : state(STATE_WAITING),
37 exception(folly::AsyncSocketException::UNKNOWN, "none") {}
// Marks the connect as succeeded, then tests whether successCallback is set.
39 void connectSuccess() noexcept override {
40 state = STATE_SUCCEEDED;
41 if (successCallback) {
// Failure notification; `ex` describes why the connect failed.
46 void connectErr(const folly::AsyncSocketException& ex) noexcept override {
// Placeholder until a failure is reported (presumably overwritten by
// connectErr() -- confirm).
55 folly::AsyncSocketException exception;
// Optional hook; checked for emptiness in connectSuccess().
56 VoidCallback successCallback;
// Optional hook; presumably the failure counterpart -- confirm in connectErr().
57 VoidCallback errorCallback;
// Write callback for tests: tracks the outcome of a write in `state`, and on
// failure also keeps the number of bytes that were written.
60 class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
// Starts in STATE_WAITING; `exception` begins as an UNKNOWN/"none" placeholder.
63 : state(STATE_WAITING),
65 exception(folly::AsyncSocketException::UNKNOWN, "none") {}
// Marks the write as succeeded, then tests whether successCallback is set.
67 void writeSuccess() noexcept override {
68 state = STATE_SUCCEEDED;
69 if (successCallback) {
// Failure notification: logs ex.what() at ERROR level and stores
// nBytesWritten in the atomic bytesWritten member.
74 void writeErr(size_t nBytesWritten,
75 const folly::AsyncSocketException& ex) noexcept override {
76 LOG(ERROR) << ex.what();
78 this->bytesWritten = nBytesWritten;
// Byte count reported to writeErr(); atomic, so it can be read without a lock.
86 std::atomic<size_t> bytesWritten;
// Placeholder until a failure is reported (presumably overwritten by
// writeErr() -- confirm).
87 folly::AsyncSocketException exception;
// Optional hook; checked for emptiness in writeSuccess().
88 VoidCallback successCallback;
// Optional hook; presumably the failure counterpart -- confirm in writeErr().
89 VoidCallback errorCallback;
// Read callback for tests: every chunk delivered through readDataAvailable()
// is appended to `buffers`, and EOF / error is reflected in `state`.
92 class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
// _maxBufferSz (default 4096) is the size handed to Buffer::allocate() each
// time a fresh read buffer is needed.
94 explicit ReadCallback(size_t _maxBufferSz = 4096)
95 : state(STATE_WAITING),
96 exception(folly::AsyncSocketException::UNKNOWN, "none"),
98 maxBufferSz(_maxBufferSz) {}
// Iterates over the collected buffers and calls free() on currentBuffer.
100 ~ReadCallback() override {
101 for (std::vector<Buffer>::iterator it = buffers.begin();
106 currentBuffer.free();
// Supplies the destination for the next read: allocates currentBuffer on
// demand (maxBufferSz bytes) and returns its pointer and length field.
109 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
110 if (!currentBuffer.buffer) {
111 currentBuffer.allocate(maxBufferSz);
113 *bufReturn = currentBuffer.buffer;
114 *lenReturn = currentBuffer.length;
// Records `len` as the size of the chunk just read, appends the chunk to
// `buffers`, resets currentBuffer, and invokes dataAvailableCallback when set.
117 void readDataAvailable(size_t len) noexcept override {
118 currentBuffer.length = len;
119 buffers.push_back(currentBuffer);
120 currentBuffer.reset();
121 if (dataAvailableCallback) {
122 dataAvailableCallback();
// End of stream counts as success.
126 void readEOF() noexcept override {
127 state = STATE_SUCCEEDED;
// A read error moves the callback to STATE_FAILED.
130 void readErr(const folly::AsyncSocketException& ex) noexcept override {
131 state = STATE_FAILED;
// Compares each collected buffer with `expected` at the running `offset`
// and finally CHECKs that `offset` equals expectedLen. A mismatch aborts
// through CHECK_EQ rather than returning a status.
135 void verifyData(const char* expected, size_t expectedLen) const {
137 for (size_t idx = 0; idx < buffers.size(); ++idx) {
138 const auto& buf = buffers[idx];
// Clamp to the bytes still remaining in `expected`; the CHECK on cmpLen below
// then rejects a buffer that is longer than what remains.
139 size_t cmpLen = std::min(buf.length, expectedLen - offset);
140 CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
141 CHECK_EQ(cmpLen, buf.length);
144 CHECK_EQ(offset, expectedLen);
// Iterates over `buffers`; presumably sums their lengths -- confirm.
147 size_t dataRead() const {
149 for (const auto& buf : buffers) {
// Buffer: a raw char pointer plus a length. Default-constructed empty, or
// wrapping a caller-supplied pointer/length pair.
157 Buffer() : buffer(nullptr), length(0) {}
158 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
// Requires that no storage is currently held (assert), then mallocs len bytes.
164 void allocate(size_t len) {
165 assert(buffer == nullptr);
166 this->buffer = static_cast<char*>(malloc(len));
// Placeholder until a failure is reported (presumably overwritten by
// readErr() -- confirm).
179 folly::AsyncSocketException exception;
// Completed chunks, in arrival order.
180 std::vector<Buffer> buffers;
// Buffer currently offered to the transport through getReadBuffer().
181 Buffer currentBuffer;
// Optional hook fired after each chunk is appended.
182 VoidCallback dataAvailableCallback;
// Allocation size for each read buffer; fixed at construction.
183 const size_t maxBufferSz;
186 class BufferCallback : public folly::AsyncTransport::BufferCallback {
188 BufferCallback() : buffered_(false), bufferCleared_(false) {}
190 void onEgressBuffered() override { buffered_ = true; }
192 void onEgressBufferCleared() override { bufferCleared_ = true; }
194 bool hasBuffered() const { return buffered_; }
196 bool hasBufferCleared() const { return bufferCleared_; }
199 bool buffered_{false};
200 bool bufferCleared_{false};
// SendMsgParamsCallback for tests: serves a caller-supplied ancillary data
// blob and records whether the socket queried the flags and the data.
206 class TestSendMsgParamsCallback :
207 public folly::AsyncSocket::SendMsgParamsCallback {
// flags: value associated with this callback (presumably what getFlagsImpl()
//        returns -- confirm).
// dataSize / data: ancillary payload that getAncillaryData() copies out.
209 TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
211 writeFlags_(folly::WriteFlags::NONE),
214 queriedFlags_(false),
// Puts the callback back into its "nothing queried yet" state: writeFlags_
// returns to NONE and both queried markers are cleared.
218 void reset(int flags) {
220 writeFlags_ = folly::WriteFlags::NONE;
221 queriedFlags_ = false;
222 queriedData_ = false;
// Marks the flags as queried. writeFlags_ is compared with NONE first; the
// assert requires `flags` to equal writeFlags_ (presumably the first call
// latches the flags and later calls must repeat them -- confirm).
225 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
227 queriedFlags_ = true;
228 if (writeFlags_ == folly::WriteFlags::NONE) {
231 assert(flags == writeFlags_);
// Copies dataSize_ bytes from data_ into the caller-provided `data`, which
// must be non-null. Applies the same writeFlags_ consistency check.
236 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
238 if (writeFlags_ == folly::WriteFlags::NONE) {
241 assert(flags == writeFlags_);
243 assert(data != nullptr);
244 memcpy(data, data_, dataSize_);
// Size query for the ancillary payload; applies the same writeFlags_
// consistency check (presumably returns dataSize_ -- confirm).
247 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
248 if (writeFlags_ == folly::WriteFlags::NONE) {
251 assert(flags == writeFlags_);
// Write flags seen by the query methods; NONE until one of them runs.
257 folly::WriteFlags writeFlags_;
266 // Create a TestServer.
267 // This immediately starts listening on an ephemeral port.
// enableTFO: presumably gates the tfo_enable() call below -- confirm.
// bufSize:   value applied to both SO_SNDBUF and SO_RCVBUF below.
// Throws folly::AsyncSocketException (INTERNAL_ERROR) when creating,
// configuring, resolving, binding or listening fails.
268 explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
269 namespace fsp = folly::portability::sockets;
// IPv4 TCP socket that will become the listening socket.
270 fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
272 throw folly::AsyncSocketException(
273 folly::AsyncSocketException::INTERNAL_ERROR,
274 "failed to create test server socket",
// Put the listening socket into non-blocking mode.
277 if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
278 throw folly::AsyncSocketException(
279 folly::AsyncSocketException::INTERNAL_ERROR,
280 "failed to put test server socket in "
// Enable TCP Fast Open on the socket; 100 is the second tfo_enable argument
// (presumably the TFO queue length -- confirm).
286 folly::detail::tfo_enable(fd_, 100);
// Passive IPv4 stream hints; service "0" requests an ephemeral port.
290 struct addrinfo hints, *res;
291 memset(&hints, 0, sizeof(hints));
292 hints.ai_family = AF_INET;
293 hints.ai_socktype = SOCK_STREAM;
294 hints.ai_flags = AI_PASSIVE;
296 if (getaddrinfo(nullptr, "0", &hints, &res)) {
297 throw folly::AsyncSocketException(
298 folly::AsyncSocketException::INTERNAL_ERROR,
299 "Attempted to bind address to socket with "
// The same bufSize is used for the send and the receive buffer; the
// setsockopt() return values are not inspected on these lines.
309 setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
310 setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
// Bind to the address resolved by getaddrinfo() above.
313 if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
314 throw folly::AsyncSocketException(
315 folly::AsyncSocketException::INTERNAL_ERROR,
316 "failed to bind to async server socket for port 10",
// Start listening with a backlog of 10 pending connections.
320 if (listen(fd_, 10) != 0) {
321 throw folly::AsyncSocketException(
322 folly::AsyncSocketException::INTERNAL_ERROR,
323 "failed to listen on test server socket",
// Read back the address the socket is bound to; the port from it is reused below.
327 address_.setFromLocalAddress(fd_);
328 // The local address will contain 0.0.0.0.
329 // Change it to 127.0.0.1, so it can be used to connect to the server
330 address_.setFromIpPort("127.0.0.1", address_.getPort());
339 // Get the address for connecting to the server
// Returned by const reference; presumably refers to the address_ member, in
// which case it must not outlive this TestServer -- confirm.
340 const folly::SocketAddress& getAddress() const {
// Waits for a pending connection on the listening socket and accepts it.
// timeout: passed unchanged to poll(), i.e. milliseconds (default 50).
// Throws folly::AsyncSocketException (INTERNAL_ERROR) when poll() times out,
// when poll() fails, or when accept() fails.
// Presumably returns the accepted descriptor -- confirm.
344 int acceptFD(int timeout=50) {
345 namespace fsp = folly::portability::sockets;
// A single pollfd entry is polled.
349 int ret = poll(&pfd, 1, timeout);
351 throw folly::AsyncSocketException(
352 folly::AsyncSocketException::INTERNAL_ERROR,
353 "test server accept() timed out");
354 } else if (ret < 0) {
355 throw folly::AsyncSocketException(
356 folly::AsyncSocketException::INTERNAL_ERROR,
357 "test server accept() poll failed",
// The peer address is not requested: both address arguments are nullptr.
361 int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
362 if (acceptedFd < 0) {
363 throw folly::AsyncSocketException(
364 folly::AsyncSocketException::INTERNAL_ERROR,
365 "test server accept() failed",
372 std::shared_ptr<BlockingSocket> accept(int timeout=50) {
373 int fd = acceptFD(timeout);
374 return std::make_shared<BlockingSocket>(fd);
// Accept the next pending connection and wrap its descriptor in an
// AsyncSocket created for the event base `evb`. Exceptions from acceptFD()
// propagate to the caller.
377 std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
379 int fd = acceptFD(timeout);
// The new socket is built directly from the accepted descriptor.
380 return folly::AsyncSocket::newSocket(evb, fd);
384 * Accept a connection, read data from it, and verify that it matches the
385 * data in the specified buffer.
387 void verifyConnection(const char* buf, size_t len) {
388 // accept a connection
389 std::shared_ptr<BlockingSocket> acceptedSocket = accept();
390 // read the data and compare it to the specified buffer
391 boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
392 acceptedSocket->readAll(readbuf.get(), len);
393 CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
394 // make sure we get EOF next
395 uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
396 CHECK_EQ(bytesRead, 0);
// Address set by the constructor: the port the socket was bound to, combined
// with 127.0.0.1 so that clients can connect to it.
401 folly::SocketAddress address_;