From 48a8ecdb0d87c31e45520fbcad1f032890c67230 Mon Sep 17 00:00:00 2001 From: James Sedgwick Date: Wed, 20 May 2015 08:34:26 -0700 Subject: [PATCH] make AsyncSocket::WriteRequest an interface Summary: This will allow a subsequent diff to implement file transfers as another type of write request Test Plan: unit Reviewed By: davejwatson@fb.com Subscribers: net-systems@, folly-diffs@, yfeldblum, chalfant, fugalh, bmatheny FB internal diff: D2080257 Signature: t1:2080257:1432044566:bcc0724d349879f46e3e58ee672aff7bf37fa5f6 --- folly/io/async/AsyncSocket.cpp | 210 +++++++++++++++++++-------------- folly/io/async/AsyncSocket.h | 1 + 2 files changed, 124 insertions(+), 87 deletions(-) diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 01ed6621..f477d5da 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -53,44 +53,24 @@ const AsyncSocketException socketShutdownForWritesEx( // the WriteRequest. /** - * A WriteRequest object tracks information about a pending write() or writev() - * operation. - * - * A new WriteRequest operation is allocated on the heap for all write - * operations that cannot be completed immediately. + * A WriteRequest object tracks information about a pending write operation. */ class AsyncSocket::WriteRequest { public: - static WriteRequest* newRequest(WriteCallback* callback, - const iovec* ops, - uint32_t opCount, - unique_ptr&& ioBuf, - WriteFlags flags) { - assert(opCount > 0); - // Since we put a variable size iovec array at the end - // of each WriteRequest, we have to manually allocate the memory. - void* buf = malloc(sizeof(WriteRequest) + - (opCount * sizeof(struct iovec))); - if (buf == nullptr) { - throw std::bad_alloc(); - } + WriteRequest(AsyncSocket* socket, + WriteRequest* next, + WriteCallback* callback, + uint32_t totalBytesWritten) : + socket_(socket), next_(next), callback_(callback), + totalBytesWritten_(totalBytesWritten) {} - return new(buf) WriteRequest(callback, ops, opCount, std::move(ioBuf), - flags); - } + virtual void destroy() = 0; - void destroy() { - this->~WriteRequest(); - free(this); - } + virtual bool performWrite() = 0; - bool cork() const { - return isSet(flags_, WriteFlags::CORK); - } + virtual void consume() = 0; - WriteFlags flags() const { - return flags_; - } + virtual bool isComplete() = 0; WriteRequest* getNext() const { return next_; @@ -100,76 +80,141 @@ class AsyncSocket::WriteRequest { return callback_; } - uint32_t getBytesWritten() const { - return bytesWritten_; + uint32_t getTotalBytesWritten() const { + return totalBytesWritten_; } - const struct iovec* getOps() const { - assert(opCount_ > opIndex_); - return writeOps_ + opIndex_; + void append(WriteRequest* next) { + assert(next_ == nullptr); + next_ = next; } - uint32_t getOpCount() const { - assert(opCount_ > opIndex_); - return opCount_ - opIndex_; + protected: + // protected destructor, to ensure callers use destroy() + virtual ~WriteRequest() {} + + AsyncSocket* socket_; ///< parent socket + WriteRequest* next_; ///< pointer to next WriteRequest + WriteCallback* callback_; ///< completion callback + uint32_t totalBytesWritten_; ///< total bytes written +}; + +/* The default WriteRequest implementation, used for write(), writev() and + * writeChain() + * + * A new BytesWriteRequest operation is allocated on the heap for all write + * operations that cannot be completed immediately. + */ +class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { + public: + static BytesWriteRequest* newRequest(AsyncSocket* socket, + WriteCallback* callback, + const iovec* ops, + uint32_t opCount, + uint32_t partialWritten, + uint32_t bytesWritten, + unique_ptr&& ioBuf, + WriteFlags flags) { + assert(opCount > 0); + // Since we put a variable size iovec array at the end + // of each BytesWriteRequest, we have to manually allocate the memory. + void* buf = malloc(sizeof(BytesWriteRequest) + + (opCount * sizeof(struct iovec))); + if (buf == nullptr) { + throw std::bad_alloc(); + } + + return new(buf) BytesWriteRequest(socket, callback, ops, opCount, + partialWritten, bytesWritten, + std::move(ioBuf), flags); } - void consume(uint32_t wholeOps, uint32_t partialBytes, - uint32_t totalBytesWritten) { - // Advance opIndex_ forward by wholeOps - opIndex_ += wholeOps; + void destroy() override { + this->~BytesWriteRequest(); + free(this); + } + + bool performWrite() override { + WriteFlags writeFlags = flags_; + if (getNext() != nullptr) { + writeFlags = writeFlags | WriteFlags::CORK; + } + bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags, + &opsWritten_, &partialBytes_); + return bytesWritten_ >= 0; + } + + bool isComplete() override { + return opsWritten_ == getOpCount(); + } + + void consume() override { + // Advance opIndex_ forward by opsWritten_ + opIndex_ += opsWritten_; assert(opIndex_ < opCount_); // If we've finished writing any IOBufs, release them if (ioBuf_) { - for (uint32_t i = wholeOps; i != 0; --i) { + for (uint32_t i = opsWritten_; i != 0; --i) { assert(ioBuf_); ioBuf_ = ioBuf_->pop(); } } - // Move partialBytes forward into the current iovec buffer + // Move partialBytes_ forward into the current iovec buffer struct iovec* currentOp = writeOps_ + opIndex_; - assert((partialBytes < currentOp->iov_len) || (currentOp->iov_len == 0)); + assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0)); currentOp->iov_base = - reinterpret_cast(currentOp->iov_base) + partialBytes; - currentOp->iov_len -= partialBytes; + reinterpret_cast(currentOp->iov_base) + partialBytes_; + currentOp->iov_len -= partialBytes_; - // Increment the bytesWritten_ count by totalBytesWritten - bytesWritten_ += totalBytesWritten; - } - - void append(WriteRequest* next) { - assert(next_ == nullptr); - next_ = next; + // Increment the totalBytesWritten_ count by bytesWritten_; + totalBytesWritten_ += bytesWritten_; } private: - WriteRequest(WriteCallback* callback, - const struct iovec* ops, - uint32_t opCount, - unique_ptr&& ioBuf, - WriteFlags flags) - : next_(nullptr) - , callback_(callback) - , bytesWritten_(0) + BytesWriteRequest(AsyncSocket* socket, + WriteCallback* callback, + const struct iovec* ops, + uint32_t opCount, + uint32_t partialBytes, + uint32_t bytesWritten, + unique_ptr&& ioBuf, + WriteFlags flags) + : AsyncSocket::WriteRequest(socket, nullptr, callback, 0) , opCount_(opCount) , opIndex_(0) , flags_(flags) - , ioBuf_(std::move(ioBuf)) { + , ioBuf_(std::move(ioBuf)) + , opsWritten_(0) + , partialBytes_(partialBytes) + , bytesWritten_(bytesWritten) { memcpy(writeOps_, ops, sizeof(*ops) * opCount_); } - // Private destructor, to ensure callers use destroy() - ~WriteRequest() {} + // private destructor, to ensure callers use destroy() + virtual ~BytesWriteRequest() {} + + const struct iovec* getOps() const { + assert(opCount_ > opIndex_); + return writeOps_ + opIndex_; + } + + uint32_t getOpCount() const { + assert(opCount_ > opIndex_); + return opCount_ - opIndex_; + } - WriteRequest* next_; ///< pointer to next WriteRequest - WriteCallback* callback_; ///< completion callback - uint32_t bytesWritten_; ///< bytes written uint32_t opCount_; ///< number of entries in writeOps_ uint32_t opIndex_; ///< current index into writeOps_ WriteFlags flags_; ///< set for WriteFlags unique_ptr ioBuf_; ///< underlying IOBuf, or nullptr if N/A + + // for consume(), how much we wrote on the last write + uint32_t opsWritten_; ///< complete ops written + uint32_t partialBytes_; ///< partial bytes of incomplete op written + ssize_t bytesWritten_; ///< bytes written altogether + struct iovec writeOps_[]; ///< write operation(s) list }; @@ -687,16 +732,16 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, // Create a new WriteRequest to add to the queue WriteRequest* req; try { - req = WriteRequest::newRequest(callback, vec + countWritten, - count - countWritten, std::move(ioBuf), - flags); + req = BytesWriteRequest::newRequest(this, callback, vec + countWritten, + count - countWritten, partialWritten, + bytesWritten, std::move(ioBuf), flags); } catch (const std::exception& ex) { // we mainly expect to catch std::bad_alloc here AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR, withAddr(string("failed to append new WriteRequest: ") + ex.what())); return failWrite(__func__, callback, bytesWritten, tex); } - req->consume(0, partialWritten, bytesWritten); + req->consume(); if (writeReqTail_ == nullptr) { assert(writeReqHead_ == nullptr); writeReqHead_ = writeReqTail_ = req; @@ -1346,20 +1391,11 @@ void AsyncSocket::handleWrite() noexcept { // (See the comment in handleRead() explaining how this can happen.) EventBase* originalEventBase = eventBase_; while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) { - uint32_t countWritten; - uint32_t partialWritten; - WriteFlags writeFlags = writeReqHead_->flags(); - if (writeReqHead_->getNext() != nullptr) { - writeFlags = writeFlags | WriteFlags::CORK; - } - int bytesWritten = performWrite(writeReqHead_->getOps(), - writeReqHead_->getOpCount(), - writeFlags, &countWritten, &partialWritten); - if (bytesWritten < 0) { + if (!writeReqHead_->performWrite()) { AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR, withAddr("writev() failed"), errno); return failWrite(__func__, ex); - } else if (countWritten == writeReqHead_->getOpCount()) { + } else if (writeReqHead_->isComplete()) { // We finished this request WriteRequest* req = writeReqHead_; writeReqHead_ = req->getNext(); @@ -1424,7 +1460,7 @@ void AsyncSocket::handleWrite() noexcept { // We'll continue around the loop, trying to write another request } else { // Partial write. - writeReqHead_->consume(countWritten, partialWritten, bytesWritten); + writeReqHead_->consume(); // Stop after a partial write; it's highly likely that a subsequent write // attempt will just return EAGAIN. // @@ -1822,7 +1858,7 @@ void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) { WriteRequest* req = writeReqHead_; writeReqHead_ = req->getNext(); WriteCallback* callback = req->getCallback(); - uint32_t bytesWritten = req->getBytesWritten(); + uint32_t bytesWritten = req->getTotalBytesWritten(); req->destroy(); if (callback) { callback->writeErr(bytesWritten, ex); @@ -1859,7 +1895,7 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) { writeReqHead_ = req->getNext(); WriteCallback* callback = req->getCallback(); if (callback) { - callback->writeErr(req->getBytesWritten(), ex); + callback->writeErr(req->getTotalBytesWritten(), ex); } req->destroy(); } diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index e6209166..866c5d91 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -517,6 +517,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { }; class WriteRequest; + class BytesWriteRequest; class WriteTimeout : public AsyncTimeout { public: -- 2.34.1