wangle/bootstrap/ClientBootstrap.h \
wangle/channel/AsyncSocketHandler.h \
wangle/channel/EventBaseHandler.h \
+ wangle/channel/FileRegion.h \
wangle/channel/Handler.h \
wangle/channel/HandlerContext.h \
wangle/channel/HandlerContext-inl.h \
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
+#include <folly/io/async/EventHandler.h>
+#include <folly/Singleton.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
+#include <thread>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
const AsyncSocketException socketShutdownForWritesEx(
AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
-// TODO: It might help performance to provide a version of WriteRequest that
+// TODO: It might help performance to provide a version of BytesWriteRequest that
// users could derive from, so we can avoid the extra allocation for each call
// to write()/writev(). We could templatize TFramedAsyncChannel just like the
// protocols are currently templatized for transports.
// storage space, and only our internal version would allocate it at the end of
// the WriteRequest.
-/**
- * A WriteRequest object tracks information about a pending write operation.
- */
-class AsyncSocket::WriteRequest {
- public:
- WriteRequest(AsyncSocket* socket,
- WriteRequest* next,
- WriteCallback* callback,
- uint32_t totalBytesWritten) :
- socket_(socket), next_(next), callback_(callback),
- totalBytesWritten_(totalBytesWritten) {}
-
- virtual void destroy() = 0;
-
- virtual bool performWrite() = 0;
-
- virtual void consume() = 0;
-
- virtual bool isComplete() = 0;
-
- WriteRequest* getNext() const {
- return next_;
- }
-
- WriteCallback* getCallback() const {
- return callback_;
- }
-
- uint32_t getTotalBytesWritten() const {
- return totalBytesWritten_;
- }
-
- void append(WriteRequest* next) {
- assert(next_ == nullptr);
- next_ = next;
- }
-
- protected:
- // protected destructor, to ensure callers use destroy()
- virtual ~WriteRequest() {}
-
- AsyncSocket* socket_; ///< parent socket
- WriteRequest* next_; ///< pointer to next WriteRequest
- WriteCallback* callback_; ///< completion callback
- uint32_t totalBytesWritten_; ///< total bytes written
-};
-
/* The default WriteRequest implementation, used for write(), writev() and
* writeChain()
*
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags)
- : AsyncSocket::WriteRequest(socket, nullptr, callback, 0)
+ : AsyncSocket::WriteRequest(socket, callback)
, opCount_(opCount)
, opIndex_(0)
, flags_(flags)
}
}
+void AsyncSocket::writeRequest(WriteRequest* req) {
+ if (writeReqTail_ == nullptr) {
+ assert(writeReqHead_ == nullptr);
+ writeReqHead_ = writeReqTail_ = req;
+ req->start();
+ } else {
+ writeReqTail_->append(req);
+ writeReqTail_ = req;
+ }
+}
+
void AsyncSocket::close() {
VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
<< ", state=" << state_ << ", shutdownFlags="
std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) override;
+ class WriteRequest;
+ virtual void writeRequest(WriteRequest* req);
+ void writeRequestReady() {
+ handleWrite();
+ }
+
// Methods inherited from AsyncTransport
void close() override;
void closeNow() override;
ERROR
};
+ /**
+ * A WriteRequest object tracks information about a pending write operation.
+ */
+ class WriteRequest {
+ public:
+ WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
+ socket_(socket), callback_(callback) {}
+
+ virtual void start() {};
+
+ virtual void destroy() = 0;
+
+ virtual bool performWrite() = 0;
+
+ virtual void consume() = 0;
+
+ virtual bool isComplete() = 0;
+
+ WriteRequest* getNext() const {
+ return next_;
+ }
+
+ WriteCallback* getCallback() const {
+ return callback_;
+ }
+
+ uint32_t getTotalBytesWritten() const {
+ return totalBytesWritten_;
+ }
+
+ void append(WriteRequest* next) {
+ assert(next_ == nullptr);
+ next_ = next;
+ }
+
+ void fail(const char* fn, const AsyncSocketException& ex) {
+ socket_->failWrite(fn, ex);
+ }
+
+ void bytesWritten(size_t count) {
+ totalBytesWritten_ += count;
+ socket_->appBytesWritten_ += count;
+ }
+
+ protected:
+ // protected destructor, to ensure callers use destroy()
+ virtual ~WriteRequest() {}
+
+ AsyncSocket* socket_; ///< parent socket
+ WriteRequest* next_{nullptr}; ///< pointer to next WriteRequest
+ WriteCallback* callback_; ///< completion callback
+ uint32_t totalBytesWritten_{0}; ///< total bytes written
+ };
+
protected:
enum ReadResultEnum {
READ_EOF = 0,
SHUT_READ = 0x04,
};
- class WriteRequest;
class BytesWriteRequest;
class WriteTimeout : public AsyncTimeout {
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/test/BlockingSocket.h>
+
+#include <boost/scoped_array.hpp>
+#include <poll.h>
+
+// This is a test-only header
+/* using override */
+using namespace folly;
+
+enum StateEnum {
+ STATE_WAITING,
+ STATE_SUCCEEDED,
+ STATE_FAILED
+};
+
+typedef std::function<void()> VoidCallback;
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+ ConnCallback()
+ : state(STATE_WAITING)
+ , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+ void connectSuccess() noexcept override {
+ state = STATE_SUCCEEDED;
+ if (successCallback) {
+ successCallback();
+ }
+ }
+
+ void connectErr(const AsyncSocketException& ex) noexcept override {
+ state = STATE_FAILED;
+ exception = ex;
+ if (errorCallback) {
+ errorCallback();
+ }
+ }
+
+ StateEnum state;
+ AsyncSocketException exception;
+ VoidCallback successCallback;
+ VoidCallback errorCallback;
+};
+
+class WriteCallback : public AsyncTransportWrapper::WriteCallback {
+ public:
+ WriteCallback()
+ : state(STATE_WAITING)
+ , bytesWritten(0)
+ , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+ void writeSuccess() noexcept override {
+ state = STATE_SUCCEEDED;
+ if (successCallback) {
+ successCallback();
+ }
+ }
+
+ void writeErr(size_t bytesWritten,
+ const AsyncSocketException& ex) noexcept override {
+ state = STATE_FAILED;
+ this->bytesWritten = bytesWritten;
+ exception = ex;
+ if (errorCallback) {
+ errorCallback();
+ }
+ }
+
+ StateEnum state;
+ size_t bytesWritten;
+ AsyncSocketException exception;
+ VoidCallback successCallback;
+ VoidCallback errorCallback;
+};
+
+class ReadCallback : public AsyncTransportWrapper::ReadCallback {
+ public:
+ ReadCallback()
+ : state(STATE_WAITING)
+ , exception(AsyncSocketException::UNKNOWN, "none")
+ , buffers() {}
+
+ ~ReadCallback() {
+ for (std::vector<Buffer>::iterator it = buffers.begin();
+ it != buffers.end();
+ ++it) {
+ it->free();
+ }
+ currentBuffer.free();
+ }
+
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ if (!currentBuffer.buffer) {
+ currentBuffer.allocate(4096);
+ }
+ *bufReturn = currentBuffer.buffer;
+ *lenReturn = currentBuffer.length;
+ }
+
+ void readDataAvailable(size_t len) noexcept override {
+ currentBuffer.length = len;
+ buffers.push_back(currentBuffer);
+ currentBuffer.reset();
+ if (dataAvailableCallback) {
+ dataAvailableCallback();
+ }
+ }
+
+ void readEOF() noexcept override {
+ state = STATE_SUCCEEDED;
+ }
+
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ state = STATE_FAILED;
+ exception = ex;
+ }
+
+ void verifyData(const char* expected, size_t expectedLen) const {
+ size_t offset = 0;
+ for (size_t idx = 0; idx < buffers.size(); ++idx) {
+ const auto& buf = buffers[idx];
+ size_t cmpLen = std::min(buf.length, expectedLen - offset);
+ CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
+ CHECK_EQ(cmpLen, buf.length);
+ offset += cmpLen;
+ }
+ CHECK_EQ(offset, expectedLen);
+ }
+
+ class Buffer {
+ public:
+ Buffer() : buffer(nullptr), length(0) {}
+ Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
+
+ void reset() {
+ buffer = nullptr;
+ length = 0;
+ }
+ void allocate(size_t length) {
+ assert(buffer == nullptr);
+ this->buffer = static_cast<char*>(malloc(length));
+ this->length = length;
+ }
+ void free() {
+ ::free(buffer);
+ reset();
+ }
+
+ char* buffer;
+ size_t length;
+ };
+
+ StateEnum state;
+ AsyncSocketException exception;
+ std::vector<Buffer> buffers;
+ Buffer currentBuffer;
+ VoidCallback dataAvailableCallback;
+};
+
+class ReadVerifier {
+};
+
+class TestServer {
+ public:
+ // Create a TestServer.
+ // This immediately starts listening on an ephemeral port.
+ TestServer()
+ : fd_(-1) {
+ fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (fd_ < 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "failed to create test server socket", errno);
+ }
+ if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "failed to put test server socket in "
+ "non-blocking mode", errno);
+ }
+ if (listen(fd_, 10) != 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "failed to listen on test server socket",
+ errno);
+ }
+
+ address_.setFromLocalAddress(fd_);
+ // The local address will contain 0.0.0.0.
+ // Change it to 127.0.0.1, so it can be used to connect to the server
+ address_.setFromIpPort("127.0.0.1", address_.getPort());
+ }
+
+ // Get the address for connecting to the server
+ const folly::SocketAddress& getAddress() const {
+ return address_;
+ }
+
+ int acceptFD(int timeout=50) {
+ struct pollfd pfd;
+ pfd.fd = fd_;
+ pfd.events = POLLIN;
+ int ret = poll(&pfd, 1, timeout);
+ if (ret == 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() timed out");
+ } else if (ret < 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() poll failed", errno);
+ }
+
+ int acceptedFd = ::accept(fd_, nullptr, nullptr);
+ if (acceptedFd < 0) {
+ throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+ "test server accept() failed", errno);
+ }
+
+ return acceptedFd;
+ }
+
+ std::shared_ptr<BlockingSocket> accept(int timeout=50) {
+ int fd = acceptFD(timeout);
+ return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+ }
+
+ std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
+ int fd = acceptFD(timeout);
+ return AsyncSocket::newSocket(evb, fd);
+ }
+
+ /**
+ * Accept a connection, read data from it, and verify that it matches the
+ * data in the specified buffer.
+ */
+ void verifyConnection(const char* buf, size_t len) {
+ // accept a connection
+ std::shared_ptr<BlockingSocket> acceptedSocket = accept();
+ // read the data and compare it to the specified buffer
+ boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
+ acceptedSocket->readAll(readbuf.get(), len);
+ CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
+ // make sure we get EOF next
+ uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
+ CHECK_EQ(bytesRead, 0);
+ }
+
+ private:
+ int fd_;
+ folly::SocketAddress address_;
+};
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
-#include <folly/io/async/test/BlockingSocket.h>
+#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/Util.h>
#include <gtest/gtest.h>
using namespace folly;
-enum StateEnum {
- STATE_WAITING,
- STATE_SUCCEEDED,
- STATE_FAILED
-};
-
-typedef std::function<void()> VoidCallback;
-
-
-class ConnCallback : public AsyncSocket::ConnectCallback {
- public:
- ConnCallback()
- : state(STATE_WAITING)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
-
- void connectSuccess() noexcept override {
- state = STATE_SUCCEEDED;
- if (successCallback) {
- successCallback();
- }
- }
-
- void connectErr(const AsyncSocketException& ex) noexcept override {
- state = STATE_FAILED;
- exception = ex;
- if (errorCallback) {
- errorCallback();
- }
- }
-
- StateEnum state;
- AsyncSocketException exception;
- VoidCallback successCallback;
- VoidCallback errorCallback;
-};
-
-class WriteCallback : public AsyncTransportWrapper::WriteCallback {
- public:
- WriteCallback()
- : state(STATE_WAITING)
- , bytesWritten(0)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
-
- void writeSuccess() noexcept override {
- state = STATE_SUCCEEDED;
- if (successCallback) {
- successCallback();
- }
- }
-
- void writeErr(size_t bytesWritten,
- const AsyncSocketException& ex) noexcept override {
- state = STATE_FAILED;
- this->bytesWritten = bytesWritten;
- exception = ex;
- if (errorCallback) {
- errorCallback();
- }
- }
-
- StateEnum state;
- size_t bytesWritten;
- AsyncSocketException exception;
- VoidCallback successCallback;
- VoidCallback errorCallback;
-};
-
-class ReadCallback : public AsyncTransportWrapper::ReadCallback {
- public:
- ReadCallback()
- : state(STATE_WAITING)
- , exception(AsyncSocketException::UNKNOWN, "none")
- , buffers() {}
-
- ~ReadCallback() {
- for (vector<Buffer>::iterator it = buffers.begin();
- it != buffers.end();
- ++it) {
- it->free();
- }
- currentBuffer.free();
- }
-
- void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
- if (!currentBuffer.buffer) {
- currentBuffer.allocate(4096);
- }
- *bufReturn = currentBuffer.buffer;
- *lenReturn = currentBuffer.length;
- }
-
- void readDataAvailable(size_t len) noexcept override {
- currentBuffer.length = len;
- buffers.push_back(currentBuffer);
- currentBuffer.reset();
- if (dataAvailableCallback) {
- dataAvailableCallback();
- }
- }
-
- void readEOF() noexcept override {
- state = STATE_SUCCEEDED;
- }
-
- void readErr(const AsyncSocketException& ex) noexcept override {
- state = STATE_FAILED;
- exception = ex;
- }
-
- void verifyData(const char* expected, size_t expectedLen) const {
- size_t offset = 0;
- for (size_t idx = 0; idx < buffers.size(); ++idx) {
- const auto& buf = buffers[idx];
- size_t cmpLen = std::min(buf.length, expectedLen - offset);
- CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
- CHECK_EQ(cmpLen, buf.length);
- offset += cmpLen;
- }
- CHECK_EQ(offset, expectedLen);
- }
-
- class Buffer {
- public:
- Buffer() : buffer(nullptr), length(0) {}
- Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
-
- void reset() {
- buffer = nullptr;
- length = 0;
- }
- void allocate(size_t length) {
- assert(buffer == nullptr);
- this->buffer = static_cast<char*>(malloc(length));
- this->length = length;
- }
- void free() {
- ::free(buffer);
- reset();
- }
-
- char* buffer;
- size_t length;
- };
-
- StateEnum state;
- AsyncSocketException exception;
- vector<Buffer> buffers;
- Buffer currentBuffer;
- VoidCallback dataAvailableCallback;
-};
-
-class ReadVerifier {
-};
-
-class TestServer {
- public:
- // Create a TestServer.
- // This immediately starts listening on an ephemeral port.
- TestServer()
- : fd_(-1) {
- fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
- if (fd_ < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "failed to create test server socket", errno);
- }
- if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "failed to put test server socket in "
- "non-blocking mode", errno);
- }
- if (listen(fd_, 10) != 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "failed to listen on test server socket",
- errno);
- }
-
- address_.setFromLocalAddress(fd_);
- // The local address will contain 0.0.0.0.
- // Change it to 127.0.0.1, so it can be used to connect to the server
- address_.setFromIpPort("127.0.0.1", address_.getPort());
- }
-
- // Get the address for connecting to the server
- const folly::SocketAddress& getAddress() const {
- return address_;
- }
-
- int acceptFD(int timeout=50) {
- struct pollfd pfd;
- pfd.fd = fd_;
- pfd.events = POLLIN;
- int ret = poll(&pfd, 1, timeout);
- if (ret == 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "test server accept() timed out");
- } else if (ret < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "test server accept() poll failed", errno);
- }
-
- int acceptedFd = ::accept(fd_, nullptr, nullptr);
- if (acceptedFd < 0) {
- throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
- "test server accept() failed", errno);
- }
-
- return acceptedFd;
- }
-
- std::shared_ptr<BlockingSocket> accept(int timeout=50) {
- int fd = acceptFD(timeout);
- return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
- }
-
- std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
- int fd = acceptFD(timeout);
- return AsyncSocket::newSocket(evb, fd);
- }
-
- /**
- * Accept a connection, read data from it, and verify that it matches the
- * data in the specified buffer.
- */
- void verifyConnection(const char* buf, size_t len) {
- // accept a connection
- std::shared_ptr<BlockingSocket> acceptedSocket = accept();
- // read the data and compare it to the specified buffer
- scoped_array<uint8_t> readbuf(new uint8_t[len]);
- acceptedSocket->readAll(readbuf.get(), len);
- CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
- // make sure we get EOF next
- uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
- CHECK_EQ(bytesRead, 0);
- }
-
- private:
- int fd_;
- folly::SocketAddress address_;
-};
-
class DelayedWrite: public AsyncTimeout {
public:
DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/wangle/channel/FileRegion.h>
+
+using namespace folly;
+using namespace folly::wangle;
+
+namespace {
+
+struct FileRegionReadPool {};
+
+Singleton<IOThreadPoolExecutor, FileRegionReadPool> readPool(
+ []{
+ return new IOThreadPoolExecutor(
+ sysconf(_SC_NPROCESSORS_ONLN),
+ std::make_shared<NamedThreadFactory>("FileRegionReadPool"));
+ });
+
+}
+
+namespace folly { namespace wangle {
+
+FileRegion::FileWriteRequest::FileWriteRequest(AsyncSocket* socket,
+ WriteCallback* callback, int fd, off_t offset, size_t count)
+ : WriteRequest(socket, callback),
+ readFd_(fd), offset_(offset), count_(count) {
+}
+
+void FileRegion::FileWriteRequest::destroy() {
+ readBase_->runInEventBaseThread([this]{
+ delete this;
+ });
+}
+
+bool FileRegion::FileWriteRequest::performWrite() {
+ if (!started_) {
+ start();
+ return true;
+ }
+
+ int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE;
+ ssize_t spliced = ::splice(pipe_out_, nullptr,
+ socket_->getFd(), nullptr,
+ bytesInPipe_, flags);
+ if (spliced == -1) {
+ if (errno == EAGAIN) {
+ return true;
+ }
+ return false;
+ }
+
+ bytesInPipe_ -= spliced;
+ bytesWritten(spliced);
+ return true;
+}
+
+void FileRegion::FileWriteRequest::consume() {
+ // do nothing
+}
+
+bool FileRegion::FileWriteRequest::isComplete() {
+ return totalBytesWritten_ == count_;
+}
+
+void FileRegion::FileWriteRequest::messageAvailable(size_t&& count) {
+ bool shouldWrite = bytesInPipe_ == 0;
+ bytesInPipe_ += count;
+ if (shouldWrite) {
+ socket_->writeRequestReady();
+ }
+}
+
+#ifdef __GLIBC__
+# if (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 9))
+# define GLIBC_AT_LEAST_2_9 1
+# endif
+#endif
+
+void FileRegion::FileWriteRequest::start() {
+ started_ = true;
+ readBase_ = readPool.get()->getEventBase();
+ readBase_->runInEventBaseThread([this]{
+ auto flags = fcntl(readFd_, F_GETFL);
+ if (flags == -1) {
+ fail(__func__, AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ "fcntl F_GETFL failed", errno));
+ return;
+ }
+
+ flags &= O_ACCMODE;
+ if (flags == O_WRONLY) {
+ fail(__func__, AsyncSocketException(
+ AsyncSocketException::BAD_ARGS, "file not open for reading"));
+ return;
+ }
+
+#ifndef GLIBC_AT_LEAST_2_9
+ fail(__func__, AsyncSocketException(
+ AsyncSocketException::NOT_SUPPORTED,
+ "writeFile unsupported on glibc < 2.9"));
+ return;
+#else
+ int pipeFds[2];
+ if (::pipe2(pipeFds, O_NONBLOCK) == -1) {
+ fail(__func__, AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ "pipe2 failed", errno));
+ return;
+ }
+
+ // Max size for unprevileged processes as set in /proc/sys/fs/pipe-max-size
+ // Ignore failures and just roll with it
+ // TODO maybe read max size from /proc?
+ fcntl(pipeFds[0], F_SETPIPE_SZ, 1048576);
+ fcntl(pipeFds[1], F_SETPIPE_SZ, 1048576);
+
+ pipe_out_ = pipeFds[0];
+
+ socket_->getEventBase()->runInEventBaseThreadAndWait([&]{
+ startConsuming(socket_->getEventBase(), &queue_);
+ });
+ readHandler_ = folly::make_unique<FileReadHandler>(
+ this, pipeFds[1], count_);
+#endif
+ });
+}
+
+FileRegion::FileWriteRequest::~FileWriteRequest() {
+ CHECK(readBase_->isInEventBaseThread());
+ socket_->getEventBase()->runInEventBaseThreadAndWait([&]{
+ stopConsuming();
+ if (pipe_out_ > -1) {
+ ::close(pipe_out_);
+ }
+ });
+
+}
+
+void FileRegion::FileWriteRequest::fail(
+ const char* fn,
+ const AsyncSocketException& ex) {
+ socket_->getEventBase()->runInEventBaseThread([=]{
+ WriteRequest::fail(fn, ex);
+ });
+}
+
+FileRegion::FileWriteRequest::FileReadHandler::FileReadHandler(
+ FileWriteRequest* req, int pipe_in, size_t bytesToRead)
+ : req_(req), pipe_in_(pipe_in), bytesToRead_(bytesToRead) {
+ CHECK(req_->readBase_->isInEventBaseThread());
+ initHandler(req_->readBase_, pipe_in);
+ if (!registerHandler(EventFlags::WRITE | EventFlags::PERSIST)) {
+ req_->fail(__func__, AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ "registerHandler failed"));
+ }
+}
+
+FileRegion::FileWriteRequest::FileReadHandler::~FileReadHandler() {
+ CHECK(req_->readBase_->isInEventBaseThread());
+ unregisterHandler();
+ ::close(pipe_in_);
+}
+
+void FileRegion::FileWriteRequest::FileReadHandler::handlerReady(
+ uint16_t events) noexcept {
+ CHECK(events & EventHandler::WRITE);
+ if (bytesToRead_ == 0) {
+ unregisterHandler();
+ return;
+ }
+
+ int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE;
+ ssize_t spliced = ::splice(req_->readFd_, &req_->offset_,
+ pipe_in_, nullptr,
+ bytesToRead_, flags);
+ if (spliced == -1) {
+ if (errno == EAGAIN) {
+ return;
+ } else {
+ req_->fail(__func__, AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ "splice failed", errno));
+ return;
+ }
+ }
+
+ if (spliced > 0) {
+ bytesToRead_ -= spliced;
+ try {
+ req_->queue_.putMessage(static_cast<size_t>(spliced));
+ } catch (...) {
+ req_->fail(__func__, AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ "putMessage failed"));
+ return;
+ }
+ }
+}
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Singleton.h>
+#include <folly/io/async/AsyncTransport.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/NotificationQueue.h>
+#include <folly/futures/Future.h>
+#include <folly/futures/Promise.h>
+#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
+
+namespace folly { namespace wangle {
+
+class FileRegion {
+ public:
+ FileRegion(int fd, off_t offset, size_t count)
+ : fd_(fd), offset_(offset), count_(count) {}
+
+ Future<void> transferTo(std::shared_ptr<AsyncTransport> transport) {
+ auto socket = std::dynamic_pointer_cast<AsyncSocket>(
+ transport);
+ CHECK(socket);
+ auto cb = new WriteCallback();
+ auto f = cb->promise_.getFuture();
+ auto req = new FileWriteRequest(socket.get(), cb, fd_, offset_, count_);
+ socket->writeRequest(req);
+ return f;
+ }
+
+ private:
+ class WriteCallback : private AsyncSocket::WriteCallback {
+ void writeSuccess() noexcept override {
+ promise_.setValue();
+ delete this;
+ }
+
+ void writeErr(size_t bytesWritten,
+ const AsyncSocketException& ex)
+ noexcept override {
+ promise_.setException(ex);
+ delete this;
+ }
+
+ friend class FileRegion;
+ folly::Promise<void> promise_;
+ };
+
+ const int fd_;
+ const off_t offset_;
+ const size_t count_;
+
+ class FileWriteRequest : public AsyncSocket::WriteRequest,
+ public NotificationQueue<size_t>::Consumer {
+ public:
+ FileWriteRequest(AsyncSocket* socket, WriteCallback* callback,
+ int fd, off_t offset, size_t count);
+
+ void destroy() override;
+
+ bool performWrite() override;
+
+ void consume() override;
+
+ bool isComplete() override;
+
+ void messageAvailable(size_t&& count) override;
+
+ void start() override;
+
+ class FileReadHandler : public folly::EventHandler {
+ public:
+ FileReadHandler(FileWriteRequest* req, int pipe_in, size_t bytesToRead);
+
+ ~FileReadHandler();
+
+ void handlerReady(uint16_t events) noexcept override;
+
+ private:
+ FileWriteRequest* req_;
+ int pipe_in_;
+ size_t bytesToRead_;
+ };
+
+ private:
+ ~FileWriteRequest();
+
+ void fail(const char* fn, const AsyncSocketException& ex);
+
+ const int readFd_;
+ off_t offset_;
+ const size_t count_;
+ bool started_{false};
+ int pipe_out_{-1};
+
+ size_t bytesInPipe_{0};
+ folly::EventBase* readBase_;
+ folly::NotificationQueue<size_t> queue_;
+ std::unique_ptr<FileReadHandler> readHandler_;
+ };
+};
+
+}} // folly::wangle
--- /dev/null
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/wangle/channel/FileRegion.h>
+#include <folly/io/async/test/AsyncSocketTest.h>
+#include <gtest/gtest.h>
+
+using namespace folly;
+using namespace folly::wangle;
+using namespace testing;
+
+struct FileRegionTest : public Test {
+ FileRegionTest() {
+ // Connect
+ socket = AsyncSocket::newSocket(&evb);
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ // Accept the connection
+ acceptedSocket = server.acceptAsync(&evb);
+ acceptedSocket->setReadCB(&rcb);
+
+ // Create temp file
+ char path[] = "/tmp/AsyncSocketTest.WriteFile.XXXXXX";
+ fd = mkostemp(path, O_RDWR);
+ EXPECT_TRUE(fd > 0);
+ EXPECT_EQ(0, unlink(path));
+ }
+
+ ~FileRegionTest() {
+ // Close up shop
+ close(fd);
+ acceptedSocket->close();
+ socket->close();
+ }
+
+ TestServer server;
+ EventBase evb;
+ std::shared_ptr<AsyncSocket> socket;
+ std::shared_ptr<AsyncSocket> acceptedSocket;
+ ConnCallback ccb;
+ ReadCallback rcb;
+ int fd;
+};
+
+TEST_F(FileRegionTest, Basic) {
+ size_t count = 1000000000; // 1 GB
+ void* zeroBuf = calloc(1, count);
+ write(fd, zeroBuf, count);
+
+ FileRegion fileRegion(fd, 0, count);
+ auto f = fileRegion.transferTo(socket);
+ try {
+ f.getVia(&evb);
+ } catch (std::exception& e) {
+ LOG(FATAL) << exceptionStr(e);
+ }
+
+ // Let the reads run to completion
+ socket->shutdownWrite();
+ evb.loop();
+
+ ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
+
+ size_t receivedBytes = 0;
+ for (auto& buf : rcb.buffers) {
+ receivedBytes += buf.length;
+ ASSERT_EQ(memcmp(buf.buffer, zeroBuf, buf.length), 0);
+ }
+ ASSERT_EQ(receivedBytes, count);
+}
+
+TEST_F(FileRegionTest, Repeated) {
+ size_t count = 1000000;
+ void* zeroBuf = calloc(1, count);
+ write(fd, zeroBuf, count);
+
+ int sendCount = 1000;
+
+ FileRegion fileRegion(fd, 0, count);
+ std::vector<Future<void>> fs;
+ for (int i = 0; i < sendCount; i++) {
+ fs.push_back(fileRegion.transferTo(socket));
+ }
+ auto f = collect(fs);
+ ASSERT_NO_THROW(f.getVia(&evb));
+
+ // Let the reads run to completion
+ socket->shutdownWrite();
+ evb.loop();
+
+ ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
+
+ size_t receivedBytes = 0;
+ for (auto& buf : rcb.buffers) {
+ receivedBytes += buf.length;
+ }
+ ASSERT_EQ(receivedBytes, sendCount*count);
+}