From: James Sedgwick Date: Mon, 8 Jun 2015 15:41:33 +0000 (-0700) Subject: AsyncSocket::writeRequest() and its first user wangle::FileRegion X-Git-Tag: v0.45.0~12 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=7dbfd2f877ab8386992e3a8d540e14b3e567e33e;p=folly.git AsyncSocket::writeRequest() and its first user wangle::FileRegion Summary: similar to D2050808, but move the functionality into AsyncSocket itself so that you have a consistent interface and contiguous writes for a single file Test Plan: added unit, will hook this up to a file server example next Reviewed By: davejwatson@fb.com Subscribers: fugalh, net-systems@, folly-diffs@, jsedgwick, yfeldblum, chalfant FB internal diff: D2084452 Signature: t1:2084452:1433181933:175158618966706db00bf6620cc86ae145d04ecf --- diff --git a/folly/Makefile.am b/folly/Makefile.am index d2377fa5..b937ae93 100644 --- a/folly/Makefile.am +++ b/folly/Makefile.am @@ -282,6 +282,7 @@ nobase_follyinclude_HEADERS = \ wangle/bootstrap/ClientBootstrap.h \ wangle/channel/AsyncSocketHandler.h \ wangle/channel/EventBaseHandler.h \ + wangle/channel/FileRegion.h \ wangle/channel/Handler.h \ wangle/channel/HandlerContext.h \ wangle/channel/HandlerContext-inl.h \ diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 939788d7..35d29c39 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include @@ -24,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +46,7 @@ const AsyncSocketException socketClosedLocallyEx( const AsyncSocketException socketShutdownForWritesEx( AsyncSocketException::END_OF_FILE, "socket shutdown for writes"); -// TODO: It might help performance to provide a version of WriteRequest that +// TODO: It might help performance to provide a version of BytesWriteRequest that // users could derive from, so we can avoid the extra allocation for each call // to 
write()/writev(). We could templatize TFramedAsyncChannel just like the // protocols are currently templatized for transports. @@ -52,53 +55,6 @@ const AsyncSocketException socketShutdownForWritesEx( // storage space, and only our internal version would allocate it at the end of // the WriteRequest. -/** - * A WriteRequest object tracks information about a pending write operation. - */ -class AsyncSocket::WriteRequest { - public: - WriteRequest(AsyncSocket* socket, - WriteRequest* next, - WriteCallback* callback, - uint32_t totalBytesWritten) : - socket_(socket), next_(next), callback_(callback), - totalBytesWritten_(totalBytesWritten) {} - - virtual void destroy() = 0; - - virtual bool performWrite() = 0; - - virtual void consume() = 0; - - virtual bool isComplete() = 0; - - WriteRequest* getNext() const { - return next_; - } - - WriteCallback* getCallback() const { - return callback_; - } - - uint32_t getTotalBytesWritten() const { - return totalBytesWritten_; - } - - void append(WriteRequest* next) { - assert(next_ == nullptr); - next_ = next; - } - - protected: - // protected destructor, to ensure callers use destroy() - virtual ~WriteRequest() {} - - AsyncSocket* socket_; ///< parent socket - WriteRequest* next_; ///< pointer to next WriteRequest - WriteCallback* callback_; ///< completion callback - uint32_t totalBytesWritten_; ///< total bytes written -}; - /* The default WriteRequest implementation, used for write(), writev() and * writeChain() * @@ -181,7 +137,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { uint32_t bytesWritten, unique_ptr&& ioBuf, WriteFlags flags) - : AsyncSocket::WriteRequest(socket, nullptr, callback, 0) + : AsyncSocket::WriteRequest(socket, callback) , opCount_(opCount) , opIndex_(0) , flags_(flags) @@ -773,6 +729,17 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, } } +void AsyncSocket::writeRequest(WriteRequest* req) { + if (writeReqTail_ == nullptr) { + assert(writeReqHead_ 
== nullptr); + writeReqHead_ = writeReqTail_ = req; + req->start(); + } else { + writeReqTail_->append(req); + writeReqTail_ = req; + } +} + void AsyncSocket::close() { VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_ << ", state=" << state_ << ", shutdownFlags=" diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 9e3a224b..523093be 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -334,6 +334,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper { std::unique_ptr&& buf, WriteFlags flags = WriteFlags::NONE) override; + class WriteRequest; + virtual void writeRequest(WriteRequest* req); + void writeRequestReady() { + handleWrite(); + } + // Methods inherited from AsyncTransport void close() override; void closeNow() override; @@ -477,6 +483,60 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ERROR }; + /** + * A WriteRequest object tracks information about a pending write operation. + */ + class WriteRequest { + public: + WriteRequest(AsyncSocket* socket, WriteCallback* callback) : + socket_(socket), callback_(callback) {} + + virtual void start() {}; + + virtual void destroy() = 0; + + virtual bool performWrite() = 0; + + virtual void consume() = 0; + + virtual bool isComplete() = 0; + + WriteRequest* getNext() const { + return next_; + } + + WriteCallback* getCallback() const { + return callback_; + } + + uint32_t getTotalBytesWritten() const { + return totalBytesWritten_; + } + + void append(WriteRequest* next) { + assert(next_ == nullptr); + next_ = next; + } + + void fail(const char* fn, const AsyncSocketException& ex) { + socket_->failWrite(fn, ex); + } + + void bytesWritten(size_t count) { + totalBytesWritten_ += count; + socket_->appBytesWritten_ += count; + } + + protected: + // protected destructor, to ensure callers use destroy() + virtual ~WriteRequest() {} + + AsyncSocket* socket_; ///< parent socket + WriteRequest* next_{nullptr}; ///< pointer to next 
WriteRequest + WriteCallback* callback_; ///< completion callback + uint32_t totalBytesWritten_{0}; ///< total bytes written + }; + protected: enum ReadResultEnum { READ_EOF = 0, @@ -516,7 +576,6 @@ class AsyncSocket : virtual public AsyncTransportWrapper { SHUT_READ = 0x04, }; - class WriteRequest; class BytesWriteRequest; class WriteTimeout : public AsyncTimeout { diff --git a/folly/io/async/test/AsyncSocketTest.h b/folly/io/async/test/AsyncSocketTest.h new file mode 100644 index 00000000..2c25d0e5 --- /dev/null +++ b/folly/io/async/test/AsyncSocketTest.h @@ -0,0 +1,265 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include + +#include +#include + +// This is a test-only header +/* using override */ +using namespace folly; + +enum StateEnum { + STATE_WAITING, + STATE_SUCCEEDED, + STATE_FAILED +}; + +typedef std::function VoidCallback; + +class ConnCallback : public AsyncSocket::ConnectCallback { + public: + ConnCallback() + : state(STATE_WAITING) + , exception(AsyncSocketException::UNKNOWN, "none") {} + + void connectSuccess() noexcept override { + state = STATE_SUCCEEDED; + if (successCallback) { + successCallback(); + } + } + + void connectErr(const AsyncSocketException& ex) noexcept override { + state = STATE_FAILED; + exception = ex; + if (errorCallback) { + errorCallback(); + } + } + + StateEnum state; + AsyncSocketException exception; + VoidCallback successCallback; + VoidCallback errorCallback; +}; + +class WriteCallback : public AsyncTransportWrapper::WriteCallback { + public: + WriteCallback() + : state(STATE_WAITING) + , bytesWritten(0) + , exception(AsyncSocketException::UNKNOWN, "none") {} + + void writeSuccess() noexcept override { + state = STATE_SUCCEEDED; + if (successCallback) { + successCallback(); + } + } + + void writeErr(size_t bytesWritten, + const AsyncSocketException& ex) noexcept override { + state = STATE_FAILED; + this->bytesWritten = bytesWritten; + exception = ex; + if (errorCallback) { + errorCallback(); + } + } + + StateEnum state; + size_t bytesWritten; + AsyncSocketException exception; + VoidCallback successCallback; + VoidCallback errorCallback; +}; + +class ReadCallback : public AsyncTransportWrapper::ReadCallback { + public: + ReadCallback() + : state(STATE_WAITING) + , exception(AsyncSocketException::UNKNOWN, "none") + , buffers() {} + + ~ReadCallback() { + for (std::vector::iterator it = buffers.begin(); + it != buffers.end(); + ++it) { + it->free(); + } + currentBuffer.free(); + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + if (!currentBuffer.buffer) { + 
currentBuffer.allocate(4096); + } + *bufReturn = currentBuffer.buffer; + *lenReturn = currentBuffer.length; + } + + void readDataAvailable(size_t len) noexcept override { + currentBuffer.length = len; + buffers.push_back(currentBuffer); + currentBuffer.reset(); + if (dataAvailableCallback) { + dataAvailableCallback(); + } + } + + void readEOF() noexcept override { + state = STATE_SUCCEEDED; + } + + void readErr(const AsyncSocketException& ex) noexcept override { + state = STATE_FAILED; + exception = ex; + } + + void verifyData(const char* expected, size_t expectedLen) const { + size_t offset = 0; + for (size_t idx = 0; idx < buffers.size(); ++idx) { + const auto& buf = buffers[idx]; + size_t cmpLen = std::min(buf.length, expectedLen - offset); + CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0); + CHECK_EQ(cmpLen, buf.length); + offset += cmpLen; + } + CHECK_EQ(offset, expectedLen); + } + + class Buffer { + public: + Buffer() : buffer(nullptr), length(0) {} + Buffer(char* buf, size_t len) : buffer(buf), length(len) {} + + void reset() { + buffer = nullptr; + length = 0; + } + void allocate(size_t length) { + assert(buffer == nullptr); + this->buffer = static_cast(malloc(length)); + this->length = length; + } + void free() { + ::free(buffer); + reset(); + } + + char* buffer; + size_t length; + }; + + StateEnum state; + AsyncSocketException exception; + std::vector buffers; + Buffer currentBuffer; + VoidCallback dataAvailableCallback; +}; + +class ReadVerifier { +}; + +class TestServer { + public: + // Create a TestServer. + // This immediately starts listening on an ephemeral port. 
+ TestServer() + : fd_(-1) { + fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); + if (fd_ < 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "failed to create test server socket", errno); + } + if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "failed to put test server socket in " + "non-blocking mode", errno); + } + if (listen(fd_, 10) != 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "failed to listen on test server socket", + errno); + } + + address_.setFromLocalAddress(fd_); + // The local address will contain 0.0.0.0. + // Change it to 127.0.0.1, so it can be used to connect to the server + address_.setFromIpPort("127.0.0.1", address_.getPort()); + } + + // Get the address for connecting to the server + const folly::SocketAddress& getAddress() const { + return address_; + } + + int acceptFD(int timeout=50) { + struct pollfd pfd; + pfd.fd = fd_; + pfd.events = POLLIN; + int ret = poll(&pfd, 1, timeout); + if (ret == 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "test server accept() timed out"); + } else if (ret < 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "test server accept() poll failed", errno); + } + + int acceptedFd = ::accept(fd_, nullptr, nullptr); + if (acceptedFd < 0) { + throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, + "test server accept() failed", errno); + } + + return acceptedFd; + } + + std::shared_ptr accept(int timeout=50) { + int fd = acceptFD(timeout); + return std::shared_ptr(new BlockingSocket(fd)); + } + + std::shared_ptr acceptAsync(EventBase* evb, int timeout=50) { + int fd = acceptFD(timeout); + return AsyncSocket::newSocket(evb, fd); + } + + /** + * Accept a connection, read data from it, and verify that it matches the + * data in the specified buffer. 
+ */ + void verifyConnection(const char* buf, size_t len) { + // accept a connection + std::shared_ptr acceptedSocket = accept(); + // read the data and compare it to the specified buffer + boost::scoped_array readbuf(new uint8_t[len]); + acceptedSocket->readAll(readbuf.get(), len); + CHECK_EQ(memcmp(buf, readbuf.get(), len), 0); + // make sure we get EOF next + uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len); + CHECK_EQ(bytesRead, 0); + } + + private: + int fd_; + folly::SocketAddress address_; +}; diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 147bec94..ca075164 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include @@ -47,246 +47,6 @@ using boost::scoped_array; using namespace folly; -enum StateEnum { - STATE_WAITING, - STATE_SUCCEEDED, - STATE_FAILED -}; - -typedef std::function VoidCallback; - - -class ConnCallback : public AsyncSocket::ConnectCallback { - public: - ConnCallback() - : state(STATE_WAITING) - , exception(AsyncSocketException::UNKNOWN, "none") {} - - void connectSuccess() noexcept override { - state = STATE_SUCCEEDED; - if (successCallback) { - successCallback(); - } - } - - void connectErr(const AsyncSocketException& ex) noexcept override { - state = STATE_FAILED; - exception = ex; - if (errorCallback) { - errorCallback(); - } - } - - StateEnum state; - AsyncSocketException exception; - VoidCallback successCallback; - VoidCallback errorCallback; -}; - -class WriteCallback : public AsyncTransportWrapper::WriteCallback { - public: - WriteCallback() - : state(STATE_WAITING) - , bytesWritten(0) - , exception(AsyncSocketException::UNKNOWN, "none") {} - - void writeSuccess() noexcept override { - state = STATE_SUCCEEDED; - if (successCallback) { - successCallback(); - } - } - - void writeErr(size_t bytesWritten, - const AsyncSocketException& ex) noexcept override { 
- state = STATE_FAILED; - this->bytesWritten = bytesWritten; - exception = ex; - if (errorCallback) { - errorCallback(); - } - } - - StateEnum state; - size_t bytesWritten; - AsyncSocketException exception; - VoidCallback successCallback; - VoidCallback errorCallback; -}; - -class ReadCallback : public AsyncTransportWrapper::ReadCallback { - public: - ReadCallback() - : state(STATE_WAITING) - , exception(AsyncSocketException::UNKNOWN, "none") - , buffers() {} - - ~ReadCallback() { - for (vector::iterator it = buffers.begin(); - it != buffers.end(); - ++it) { - it->free(); - } - currentBuffer.free(); - } - - void getReadBuffer(void** bufReturn, size_t* lenReturn) override { - if (!currentBuffer.buffer) { - currentBuffer.allocate(4096); - } - *bufReturn = currentBuffer.buffer; - *lenReturn = currentBuffer.length; - } - - void readDataAvailable(size_t len) noexcept override { - currentBuffer.length = len; - buffers.push_back(currentBuffer); - currentBuffer.reset(); - if (dataAvailableCallback) { - dataAvailableCallback(); - } - } - - void readEOF() noexcept override { - state = STATE_SUCCEEDED; - } - - void readErr(const AsyncSocketException& ex) noexcept override { - state = STATE_FAILED; - exception = ex; - } - - void verifyData(const char* expected, size_t expectedLen) const { - size_t offset = 0; - for (size_t idx = 0; idx < buffers.size(); ++idx) { - const auto& buf = buffers[idx]; - size_t cmpLen = std::min(buf.length, expectedLen - offset); - CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0); - CHECK_EQ(cmpLen, buf.length); - offset += cmpLen; - } - CHECK_EQ(offset, expectedLen); - } - - class Buffer { - public: - Buffer() : buffer(nullptr), length(0) {} - Buffer(char* buf, size_t len) : buffer(buf), length(len) {} - - void reset() { - buffer = nullptr; - length = 0; - } - void allocate(size_t length) { - assert(buffer == nullptr); - this->buffer = static_cast(malloc(length)); - this->length = length; - } - void free() { - ::free(buffer); - reset(); - 
} - - char* buffer; - size_t length; - }; - - StateEnum state; - AsyncSocketException exception; - vector buffers; - Buffer currentBuffer; - VoidCallback dataAvailableCallback; -}; - -class ReadVerifier { -}; - -class TestServer { - public: - // Create a TestServer. - // This immediately starts listening on an ephemeral port. - TestServer() - : fd_(-1) { - fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); - if (fd_ < 0) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - "failed to create test server socket", errno); - } - if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - "failed to put test server socket in " - "non-blocking mode", errno); - } - if (listen(fd_, 10) != 0) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - "failed to listen on test server socket", - errno); - } - - address_.setFromLocalAddress(fd_); - // The local address will contain 0.0.0.0. - // Change it to 127.0.0.1, so it can be used to connect to the server - address_.setFromIpPort("127.0.0.1", address_.getPort()); - } - - // Get the address for connecting to the server - const folly::SocketAddress& getAddress() const { - return address_; - } - - int acceptFD(int timeout=50) { - struct pollfd pfd; - pfd.fd = fd_; - pfd.events = POLLIN; - int ret = poll(&pfd, 1, timeout); - if (ret == 0) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - "test server accept() timed out"); - } else if (ret < 0) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - "test server accept() poll failed", errno); - } - - int acceptedFd = ::accept(fd_, nullptr, nullptr); - if (acceptedFd < 0) { - throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR, - "test server accept() failed", errno); - } - - return acceptedFd; - } - - std::shared_ptr accept(int timeout=50) { - int fd = acceptFD(timeout); - return std::shared_ptr(new BlockingSocket(fd)); - } - - std::shared_ptr 
acceptAsync(EventBase* evb, int timeout=50) { - int fd = acceptFD(timeout); - return AsyncSocket::newSocket(evb, fd); - } - - /** - * Accept a connection, read data from it, and verify that it matches the - * data in the specified buffer. - */ - void verifyConnection(const char* buf, size_t len) { - // accept a connection - std::shared_ptr acceptedSocket = accept(); - // read the data and compare it to the specified buffer - scoped_array readbuf(new uint8_t[len]); - acceptedSocket->readAll(readbuf.get(), len); - CHECK_EQ(memcmp(buf, readbuf.get(), len), 0); - // make sure we get EOF next - uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len); - CHECK_EQ(bytesRead, 0); - } - - private: - int fd_; - folly::SocketAddress address_; -}; - class DelayedWrite: public AsyncTimeout { public: DelayedWrite(const std::shared_ptr& socket, diff --git a/folly/wangle/channel/FileRegion.cpp b/folly/wangle/channel/FileRegion.cpp new file mode 100644 index 00000000..7d14a4af --- /dev/null +++ b/folly/wangle/channel/FileRegion.cpp @@ -0,0 +1,214 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +using namespace folly; +using namespace folly::wangle; + +namespace { + +struct FileRegionReadPool {}; + +Singleton readPool( + []{ + return new IOThreadPoolExecutor( + sysconf(_SC_NPROCESSORS_ONLN), + std::make_shared("FileRegionReadPool")); + }); + +} + +namespace folly { namespace wangle { + +FileRegion::FileWriteRequest::FileWriteRequest(AsyncSocket* socket, + WriteCallback* callback, int fd, off_t offset, size_t count) + : WriteRequest(socket, callback), + readFd_(fd), offset_(offset), count_(count) { +} + +void FileRegion::FileWriteRequest::destroy() { + readBase_->runInEventBaseThread([this]{ + delete this; + }); +} + +bool FileRegion::FileWriteRequest::performWrite() { + if (!started_) { + start(); + return true; + } + + int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE; + ssize_t spliced = ::splice(pipe_out_, nullptr, + socket_->getFd(), nullptr, + bytesInPipe_, flags); + if (spliced == -1) { + if (errno == EAGAIN) { + return true; + } + return false; + } + + bytesInPipe_ -= spliced; + bytesWritten(spliced); + return true; +} + +void FileRegion::FileWriteRequest::consume() { + // do nothing +} + +bool FileRegion::FileWriteRequest::isComplete() { + return totalBytesWritten_ == count_; +} + +void FileRegion::FileWriteRequest::messageAvailable(size_t&& count) { + bool shouldWrite = bytesInPipe_ == 0; + bytesInPipe_ += count; + if (shouldWrite) { + socket_->writeRequestReady(); + } +} + +#ifdef __GLIBC__ +# if (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 9)) +# define GLIBC_AT_LEAST_2_9 1 +# endif +#endif + +void FileRegion::FileWriteRequest::start() { + started_ = true; + readBase_ = readPool.get()->getEventBase(); + readBase_->runInEventBaseThread([this]{ + auto flags = fcntl(readFd_, F_GETFL); + if (flags == -1) { + fail(__func__, AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + "fcntl F_GETFL failed", errno)); + return; + } + + flags &= O_ACCMODE; + if (flags == O_WRONLY) { + fail(__func__, AsyncSocketException( + 
AsyncSocketException::BAD_ARGS, "file not open for reading")); + return; + } + +#ifndef GLIBC_AT_LEAST_2_9 + fail(__func__, AsyncSocketException( + AsyncSocketException::NOT_SUPPORTED, + "writeFile unsupported on glibc < 2.9")); + return; +#else + int pipeFds[2]; + if (::pipe2(pipeFds, O_NONBLOCK) == -1) { + fail(__func__, AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + "pipe2 failed", errno)); + return; + } + + // Max size for unprivileged processes as set in /proc/sys/fs/pipe-max-size + // Ignore failures and just roll with it + // TODO maybe read max size from /proc? + fcntl(pipeFds[0], F_SETPIPE_SZ, 1048576); + fcntl(pipeFds[1], F_SETPIPE_SZ, 1048576); + + pipe_out_ = pipeFds[0]; + + socket_->getEventBase()->runInEventBaseThreadAndWait([&]{ + startConsuming(socket_->getEventBase(), &queue_); + }); + readHandler_ = folly::make_unique( + this, pipeFds[1], count_); +#endif + }); +} + +FileRegion::FileWriteRequest::~FileWriteRequest() { + CHECK(readBase_->isInEventBaseThread()); + socket_->getEventBase()->runInEventBaseThreadAndWait([&]{ + stopConsuming(); + if (pipe_out_ > -1) { + ::close(pipe_out_); + } + }); + +} + +void FileRegion::FileWriteRequest::fail( + const char* fn, + const AsyncSocketException& ex) { + socket_->getEventBase()->runInEventBaseThread([=]{ + WriteRequest::fail(fn, ex); + }); +} + +FileRegion::FileWriteRequest::FileReadHandler::FileReadHandler( + FileWriteRequest* req, int pipe_in, size_t bytesToRead) + : req_(req), pipe_in_(pipe_in), bytesToRead_(bytesToRead) { + CHECK(req_->readBase_->isInEventBaseThread()); + initHandler(req_->readBase_, pipe_in); + if (!registerHandler(EventFlags::WRITE | EventFlags::PERSIST)) { + req_->fail(__func__, AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + "registerHandler failed")); + } +} + +FileRegion::FileWriteRequest::FileReadHandler::~FileReadHandler() { + CHECK(req_->readBase_->isInEventBaseThread()); + unregisterHandler(); + ::close(pipe_in_); +} + +void 
FileRegion::FileWriteRequest::FileReadHandler::handlerReady( + uint16_t events) noexcept { + CHECK(events & EventHandler::WRITE); + if (bytesToRead_ == 0) { + unregisterHandler(); + return; + } + + int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE; + ssize_t spliced = ::splice(req_->readFd_, &req_->offset_, + pipe_in_, nullptr, + bytesToRead_, flags); + if (spliced == -1) { + if (errno == EAGAIN) { + return; + } else { + req_->fail(__func__, AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + "splice failed", errno)); + return; + } + } + + if (spliced > 0) { + bytesToRead_ -= spliced; + try { + req_->queue_.putMessage(static_cast(spliced)); + } catch (...) { + req_->fail(__func__, AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + "putMessage failed")); + return; + } + } +} +}} // folly::wangle diff --git a/folly/wangle/channel/FileRegion.h b/folly/wangle/channel/FileRegion.h new file mode 100644 index 00000000..6360ae35 --- /dev/null +++ b/folly/wangle/channel/FileRegion.h @@ -0,0 +1,116 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class FileRegion { + public: + FileRegion(int fd, off_t offset, size_t count) + : fd_(fd), offset_(offset), count_(count) {} + + Future transferTo(std::shared_ptr transport) { + auto socket = std::dynamic_pointer_cast( + transport); + CHECK(socket); + auto cb = new WriteCallback(); + auto f = cb->promise_.getFuture(); + auto req = new FileWriteRequest(socket.get(), cb, fd_, offset_, count_); + socket->writeRequest(req); + return f; + } + + private: + class WriteCallback : private AsyncSocket::WriteCallback { + void writeSuccess() noexcept override { + promise_.setValue(); + delete this; + } + + void writeErr(size_t bytesWritten, + const AsyncSocketException& ex) + noexcept override { + promise_.setException(ex); + delete this; + } + + friend class FileRegion; + folly::Promise promise_; + }; + + const int fd_; + const off_t offset_; + const size_t count_; + + class FileWriteRequest : public AsyncSocket::WriteRequest, + public NotificationQueue::Consumer { + public: + FileWriteRequest(AsyncSocket* socket, WriteCallback* callback, + int fd, off_t offset, size_t count); + + void destroy() override; + + bool performWrite() override; + + void consume() override; + + bool isComplete() override; + + void messageAvailable(size_t&& count) override; + + void start() override; + + class FileReadHandler : public folly::EventHandler { + public: + FileReadHandler(FileWriteRequest* req, int pipe_in, size_t bytesToRead); + + ~FileReadHandler(); + + void handlerReady(uint16_t events) noexcept override; + + private: + FileWriteRequest* req_; + int pipe_in_; + size_t bytesToRead_; + }; + + private: + ~FileWriteRequest(); + + void fail(const char* fn, const AsyncSocketException& ex); + + const int readFd_; + off_t offset_; + const size_t count_; + bool started_{false}; + int pipe_out_{-1}; + + size_t bytesInPipe_{0}; + folly::EventBase* readBase_; + 
folly::NotificationQueue queue_; + std::unique_ptr readHandler_; + }; +}; + +}} // folly::wangle diff --git a/folly/wangle/channel/test/FileRegionTest.cpp b/folly/wangle/channel/test/FileRegionTest.cpp new file mode 100644 index 00000000..ff12fc2f --- /dev/null +++ b/folly/wangle/channel/test/FileRegionTest.cpp @@ -0,0 +1,110 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace testing; + +struct FileRegionTest : public Test { + FileRegionTest() { + // Connect + socket = AsyncSocket::newSocket(&evb); + socket->connect(&ccb, server.getAddress(), 30); + + // Accept the connection + acceptedSocket = server.acceptAsync(&evb); + acceptedSocket->setReadCB(&rcb); + + // Create temp file + char path[] = "/tmp/AsyncSocketTest.WriteFile.XXXXXX"; + fd = mkostemp(path, O_RDWR); + EXPECT_TRUE(fd > 0); + EXPECT_EQ(0, unlink(path)); + } + + ~FileRegionTest() { + // Close up shop + close(fd); + acceptedSocket->close(); + socket->close(); + } + + TestServer server; + EventBase evb; + std::shared_ptr socket; + std::shared_ptr acceptedSocket; + ConnCallback ccb; + ReadCallback rcb; + int fd; +}; + +TEST_F(FileRegionTest, Basic) { + size_t count = 1000000000; // 1 GB + void* zeroBuf = calloc(1, count); + write(fd, zeroBuf, count); + + FileRegion fileRegion(fd, 0, count); + auto f = fileRegion.transferTo(socket); + try { + 
f.getVia(&evb); + } catch (std::exception& e) { + LOG(FATAL) << exceptionStr(e); + } + + // Let the reads run to completion + socket->shutdownWrite(); + evb.loop(); + + ASSERT_EQ(rcb.state, STATE_SUCCEEDED); + + size_t receivedBytes = 0; + for (auto& buf : rcb.buffers) { + receivedBytes += buf.length; + ASSERT_EQ(memcmp(buf.buffer, zeroBuf, buf.length), 0); + } + ASSERT_EQ(receivedBytes, count); +} + +TEST_F(FileRegionTest, Repeated) { + size_t count = 1000000; + void* zeroBuf = calloc(1, count); + write(fd, zeroBuf, count); + + int sendCount = 1000; + + FileRegion fileRegion(fd, 0, count); + std::vector> fs; + for (int i = 0; i < sendCount; i++) { + fs.push_back(fileRegion.transferTo(socket)); + } + auto f = collect(fs); + ASSERT_NO_THROW(f.getVia(&evb)); + + // Let the reads run to completion + socket->shutdownWrite(); + evb.loop(); + + ASSERT_EQ(rcb.state, STATE_SUCCEEDED); + + size_t receivedBytes = 0; + for (auto& buf : rcb.buffers) { + receivedBytes += buf.length; + } + ASSERT_EQ(receivedBytes, sendCount*count); +}