Invoking correct callback during TFO fallback
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/SocketAddress.h>
21 #include <folly/io/IOBuf.h>
22 #include <folly/Portability.h>
23 #include <folly/portability/Fcntl.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/SysUio.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <errno.h>
29 #include <limits.h>
30 #include <thread>
31 #include <sys/types.h>
32 #include <boost/preprocessor/control/if.hpp>
33
34 using std::string;
35 using std::unique_ptr;
36
37 namespace fsp = folly::portability::sockets;
38
39 namespace folly {
40
41 // static members initializers
42 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
43
44 const AsyncSocketException socketClosedLocallyEx(
45     AsyncSocketException::END_OF_FILE, "socket closed locally");
46 const AsyncSocketException socketShutdownForWritesEx(
47     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
48
49 // TODO: It might help performance to provide a version of BytesWriteRequest that
50 // users could derive from, so we can avoid the extra allocation for each call
51 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
52 // protocols are currently templatized for transports.
53 //
54 // We would need the version for external users where they provide the iovec
55 // storage space, and only our internal version would allocate it at the end of
56 // the WriteRequest.
57
58 /* The default WriteRequest implementation, used for write(), writev() and
59  * writeChain()
60  *
61  * A new BytesWriteRequest operation is allocated on the heap for all write
62  * operations that cannot be completed immediately.
63  */
64 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
65  public:
66   static BytesWriteRequest* newRequest(AsyncSocket* socket,
67                                        WriteCallback* callback,
68                                        const iovec* ops,
69                                        uint32_t opCount,
70                                        uint32_t partialWritten,
71                                        uint32_t bytesWritten,
72                                        unique_ptr<IOBuf>&& ioBuf,
73                                        WriteFlags flags) {
74     assert(opCount > 0);
75     // Since we put a variable size iovec array at the end
76     // of each BytesWriteRequest, we have to manually allocate the memory.
77     void* buf = malloc(sizeof(BytesWriteRequest) +
78                        (opCount * sizeof(struct iovec)));
79     if (buf == nullptr) {
80       throw std::bad_alloc();
81     }
82
83     return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
84                                       partialWritten, bytesWritten,
85                                       std::move(ioBuf), flags);
86   }
87
88   void destroy() override {
89     this->~BytesWriteRequest();
90     free(this);
91   }
92
93   WriteResult performWrite() override {
94     WriteFlags writeFlags = flags_;
95     if (getNext() != nullptr) {
96       writeFlags = writeFlags | WriteFlags::CORK;
97     }
98     return socket_->performWrite(
99         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
100   }
101
102   bool isComplete() override {
103     return opsWritten_ == getOpCount();
104   }
105
106   void consume() override {
107     // Advance opIndex_ forward by opsWritten_
108     opIndex_ += opsWritten_;
109     assert(opIndex_ < opCount_);
110
111     // If we've finished writing any IOBufs, release them
112     if (ioBuf_) {
113       for (uint32_t i = opsWritten_; i != 0; --i) {
114         assert(ioBuf_);
115         ioBuf_ = ioBuf_->pop();
116       }
117     }
118
119     // Move partialBytes_ forward into the current iovec buffer
120     struct iovec* currentOp = writeOps_ + opIndex_;
121     assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
122     currentOp->iov_base =
123       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
124     currentOp->iov_len -= partialBytes_;
125
126     // Increment the totalBytesWritten_ count by bytesWritten_;
127     totalBytesWritten_ += bytesWritten_;
128   }
129
130  private:
131   BytesWriteRequest(AsyncSocket* socket,
132                     WriteCallback* callback,
133                     const struct iovec* ops,
134                     uint32_t opCount,
135                     uint32_t partialBytes,
136                     uint32_t bytesWritten,
137                     unique_ptr<IOBuf>&& ioBuf,
138                     WriteFlags flags)
139     : AsyncSocket::WriteRequest(socket, callback)
140     , opCount_(opCount)
141     , opIndex_(0)
142     , flags_(flags)
143     , ioBuf_(std::move(ioBuf))
144     , opsWritten_(0)
145     , partialBytes_(partialBytes)
146     , bytesWritten_(bytesWritten) {
147     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
148   }
149
150   // private destructor, to ensure callers use destroy()
151   ~BytesWriteRequest() override = default;
152
153   const struct iovec* getOps() const {
154     assert(opCount_ > opIndex_);
155     return writeOps_ + opIndex_;
156   }
157
158   uint32_t getOpCount() const {
159     assert(opCount_ > opIndex_);
160     return opCount_ - opIndex_;
161   }
162
163   uint32_t opCount_;            ///< number of entries in writeOps_
164   uint32_t opIndex_;            ///< current index into writeOps_
165   WriteFlags flags_;            ///< set for WriteFlags
166   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
167
168   // for consume(), how much we wrote on the last write
169   uint32_t opsWritten_;         ///< complete ops written
170   uint32_t partialBytes_;       ///< partial bytes of incomplete op written
171   ssize_t bytesWritten_;        ///< bytes written altogether
172
173   struct iovec writeOps_[];     ///< write operation(s) list
174 };
175
176 AsyncSocket::AsyncSocket()
177     : eventBase_(nullptr),
178       writeTimeout_(this, nullptr),
179       ioHandler_(this, nullptr),
180       immediateReadHandler_(this) {
181   VLOG(5) << "new AsyncSocket()";
182   init();
183 }
184
185 AsyncSocket::AsyncSocket(EventBase* evb)
186     : eventBase_(evb),
187       writeTimeout_(this, evb),
188       ioHandler_(this, evb),
189       immediateReadHandler_(this) {
190   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
191   init();
192 }
193
194 AsyncSocket::AsyncSocket(EventBase* evb,
195                            const folly::SocketAddress& address,
196                            uint32_t connectTimeout)
197   : AsyncSocket(evb) {
198   connect(nullptr, address, connectTimeout);
199 }
200
201 AsyncSocket::AsyncSocket(EventBase* evb,
202                            const std::string& ip,
203                            uint16_t port,
204                            uint32_t connectTimeout)
205   : AsyncSocket(evb) {
206   connect(nullptr, ip, port, connectTimeout);
207 }
208
209 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
210     : eventBase_(evb),
211       writeTimeout_(this, evb),
212       ioHandler_(this, evb, fd),
213       immediateReadHandler_(this) {
214   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
215           << fd << ")";
216   init();
217   fd_ = fd;
218   setCloseOnExec();
219   state_ = StateEnum::ESTABLISHED;
220 }
221
222 // init() method, since constructor forwarding isn't supported in most
223 // compilers yet.
224 void AsyncSocket::init() {
225   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
226   shutdownFlags_ = 0;
227   state_ = StateEnum::UNINIT;
228   eventFlags_ = EventHandler::NONE;
229   fd_ = -1;
230   sendTimeout_ = 0;
231   maxReadsPerEvent_ = 16;
232   connectCallback_ = nullptr;
233   readCallback_ = nullptr;
234   writeReqHead_ = nullptr;
235   writeReqTail_ = nullptr;
236   shutdownSocketSet_ = nullptr;
237   appBytesWritten_ = 0;
238   appBytesReceived_ = 0;
239 }
240
241 AsyncSocket::~AsyncSocket() {
242   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
243           << ", evb=" << eventBase_ << ", fd=" << fd_
244           << ", state=" << state_ << ")";
245 }
246
247 void AsyncSocket::destroy() {
248   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
249           << ", fd=" << fd_ << ", state=" << state_;
250   // When destroy is called, close the socket immediately
251   closeNow();
252
253   // Then call DelayedDestruction::destroy() to take care of
254   // whether or not we need immediate or delayed destruction
255   DelayedDestruction::destroy();
256 }
257
258 int AsyncSocket::detachFd() {
259   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
260           << ", evb=" << eventBase_ << ", state=" << state_
261           << ", events=" << std::hex << eventFlags_ << ")";
262   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
263   // actually close the descriptor.
264   if (shutdownSocketSet_) {
265     shutdownSocketSet_->remove(fd_);
266   }
267   int fd = fd_;
268   fd_ = -1;
269   // Call closeNow() to invoke all pending callbacks with an error.
270   closeNow();
271   // Update the EventHandler to stop using this fd.
272   // This can only be done after closeNow() unregisters the handler.
273   ioHandler_.changeHandlerFD(-1);
274   return fd;
275 }
276
277 const folly::SocketAddress& AsyncSocket::anyAddress() {
278   static const folly::SocketAddress anyAddress =
279     folly::SocketAddress("0.0.0.0", 0);
280   return anyAddress;
281 }
282
283 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
284   if (shutdownSocketSet_ == newSS) {
285     return;
286   }
287   if (shutdownSocketSet_ && fd_ != -1) {
288     shutdownSocketSet_->remove(fd_);
289   }
290   shutdownSocketSet_ = newSS;
291   if (shutdownSocketSet_ && fd_ != -1) {
292     shutdownSocketSet_->add(fd_);
293   }
294 }
295
296 void AsyncSocket::setCloseOnExec() {
297   int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
298   if (rv != 0) {
299     auto errnoCopy = errno;
300     throw AsyncSocketException(
301         AsyncSocketException::INTERNAL_ERROR,
302         withAddr("failed to set close-on-exec flag"),
303         errnoCopy);
304   }
305 }
306
307 void AsyncSocket::connect(ConnectCallback* callback,
308                            const folly::SocketAddress& address,
309                            int timeout,
310                            const OptionMap &options,
311                            const folly::SocketAddress& bindAddr) noexcept {
312   DestructorGuard dg(this);
313   assert(eventBase_->isInEventBaseThread());
314
315   addr_ = address;
316
317   // Make sure we're in the uninitialized state
318   if (state_ != StateEnum::UNINIT) {
319     return invalidState(callback);
320   }
321
322   connectTimeout_ = std::chrono::milliseconds(timeout);
323   connectStartTime_ = std::chrono::steady_clock::now();
324   // Make connect end time at least >= connectStartTime.
325   connectEndTime_ = connectStartTime_;
326
327   assert(fd_ == -1);
328   state_ = StateEnum::CONNECTING;
329   connectCallback_ = callback;
330
331   sockaddr_storage addrStorage;
332   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
333
334   try {
335     // Create the socket
336     // Technically the first parameter should actually be a protocol family
337     // constant (PF_xxx) rather than an address family (AF_xxx), but the
338     // distinction is mainly just historical.  In pretty much all
339     // implementations the PF_foo and AF_foo constants are identical.
340     fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
341     if (fd_ < 0) {
342       auto errnoCopy = errno;
343       throw AsyncSocketException(
344           AsyncSocketException::INTERNAL_ERROR,
345           withAddr("failed to create socket"),
346           errnoCopy);
347     }
348     if (shutdownSocketSet_) {
349       shutdownSocketSet_->add(fd_);
350     }
351     ioHandler_.changeHandlerFD(fd_);
352
353     setCloseOnExec();
354
355     // Put the socket in non-blocking mode
356     int flags = fcntl(fd_, F_GETFL, 0);
357     if (flags == -1) {
358       auto errnoCopy = errno;
359       throw AsyncSocketException(
360           AsyncSocketException::INTERNAL_ERROR,
361           withAddr("failed to get socket flags"),
362           errnoCopy);
363     }
364     int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
365     if (rv == -1) {
366       auto errnoCopy = errno;
367       throw AsyncSocketException(
368           AsyncSocketException::INTERNAL_ERROR,
369           withAddr("failed to put socket in non-blocking mode"),
370           errnoCopy);
371     }
372
373 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
374     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
375     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
376     if (rv == -1) {
377       auto errnoCopy = errno;
378       throw AsyncSocketException(
379           AsyncSocketException::INTERNAL_ERROR,
380           "failed to enable F_SETNOSIGPIPE on socket",
381           errnoCopy);
382     }
383 #endif
384
385     // By default, turn on TCP_NODELAY
386     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
387     // setNoDelay() will log an error message if it fails.
388     if (address.getFamily() != AF_UNIX) {
389       (void)setNoDelay(true);
390     }
391
392     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
393             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
394
395     // bind the socket
396     if (bindAddr != anyAddress()) {
397       int one = 1;
398       if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
399         auto errnoCopy = errno;
400         doClose();
401         throw AsyncSocketException(
402             AsyncSocketException::NOT_OPEN,
403             "failed to setsockopt prior to bind on " + bindAddr.describe(),
404             errnoCopy);
405       }
406
407       bindAddr.getAddress(&addrStorage);
408
409       if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
410         auto errnoCopy = errno;
411         doClose();
412         throw AsyncSocketException(
413             AsyncSocketException::NOT_OPEN,
414             "failed to bind to async socket: " + bindAddr.describe(),
415             errnoCopy);
416       }
417     }
418
419     // Apply the additional options if any.
420     for (const auto& opt: options) {
421       int rv = opt.first.apply(fd_, opt.second);
422       if (rv != 0) {
423         auto errnoCopy = errno;
424         throw AsyncSocketException(
425             AsyncSocketException::INTERNAL_ERROR,
426             withAddr("failed to set socket option"),
427             errnoCopy);
428       }
429     }
430
431     // Perform the connect()
432     address.getAddress(&addrStorage);
433
434     if (tfoEnabled_) {
435       state_ = StateEnum::FAST_OPEN;
436       tfoAttempted_ = true;
437     } else {
438       if (socketConnect(saddr, addr_.getActualSize()) < 0) {
439         return;
440       }
441     }
442
443     // If we're still here the connect() succeeded immediately.
444     // Fall through to call the callback outside of this try...catch block
445   } catch (const AsyncSocketException& ex) {
446     return failConnect(__func__, ex);
447   } catch (const std::exception& ex) {
448     // shouldn't happen, but handle it just in case
449     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
450                << "): unexpected " << typeid(ex).name() << " exception: "
451                << ex.what();
452     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
453                             withAddr(string("unexpected exception: ") +
454                                      ex.what()));
455     return failConnect(__func__, tex);
456   }
457
458   // The connection succeeded immediately
459   // The read callback may not have been set yet, and no writes may be pending
460   // yet, so we don't have to register for any events at the moment.
461   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
462   assert(readCallback_ == nullptr);
463   assert(writeReqHead_ == nullptr);
464   if (state_ != StateEnum::FAST_OPEN) {
465     state_ = StateEnum::ESTABLISHED;
466   }
467   invokeConnectSuccess();
468 }
469
470 int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
471   int rv = fsp::connect(fd_, saddr, len);
472   if (rv < 0) {
473     auto errnoCopy = errno;
474     if (errnoCopy == EINPROGRESS) {
475       scheduleConnectTimeout();
476       registerForConnectEvents();
477     } else {
478       throw AsyncSocketException(
479           AsyncSocketException::NOT_OPEN,
480           "connect failed (immediately)",
481           errnoCopy);
482     }
483   }
484   return rv;
485 }
486
487 void AsyncSocket::scheduleConnectTimeout() {
488   // Connection in progress.
489   int timeout = connectTimeout_.count();
490   if (timeout > 0) {
491     // Start a timer in case the connection takes too long.
492     if (!writeTimeout_.scheduleTimeout(timeout)) {
493       throw AsyncSocketException(
494           AsyncSocketException::INTERNAL_ERROR,
495           withAddr("failed to schedule AsyncSocket connect timeout"));
496     }
497   }
498 }
499
500 void AsyncSocket::registerForConnectEvents() {
501   // Register for write events, so we'll
502   // be notified when the connection finishes/fails.
503   // Note that we don't register for a persistent event here.
504   assert(eventFlags_ == EventHandler::NONE);
505   eventFlags_ = EventHandler::WRITE;
506   if (!ioHandler_.registerHandler(eventFlags_)) {
507     throw AsyncSocketException(
508         AsyncSocketException::INTERNAL_ERROR,
509         withAddr("failed to register AsyncSocket connect handler"));
510   }
511 }
512
513 void AsyncSocket::connect(ConnectCallback* callback,
514                            const string& ip, uint16_t port,
515                            int timeout,
516                            const OptionMap &options) noexcept {
517   DestructorGuard dg(this);
518   try {
519     connectCallback_ = callback;
520     connect(callback, folly::SocketAddress(ip, port), timeout, options);
521   } catch (const std::exception& ex) {
522     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
523                             ex.what());
524     return failConnect(__func__, tex);
525   }
526 }
527
528 void AsyncSocket::cancelConnect() {
529   connectCallback_ = nullptr;
530   if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
531     closeNow();
532   }
533 }
534
535 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
536   sendTimeout_ = milliseconds;
537   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
538
539   // If we are currently pending on write requests, immediately update
540   // writeTimeout_ with the new value.
541   if ((eventFlags_ & EventHandler::WRITE) &&
542       (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
543     assert(state_ == StateEnum::ESTABLISHED);
544     assert((shutdownFlags_ & SHUT_WRITE) == 0);
545     if (sendTimeout_ > 0) {
546       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
547         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
548             withAddr("failed to reschedule send timeout in setSendTimeout"));
549         return failWrite(__func__, ex);
550       }
551     } else {
552       writeTimeout_.cancelTimeout();
553     }
554   }
555 }
556
557 void AsyncSocket::setReadCB(ReadCallback *callback) {
558   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
559           << ", callback=" << callback << ", state=" << state_;
560
561   // Short circuit if callback is the same as the existing readCallback_.
562   //
563   // Note that this is needed for proper functioning during some cleanup cases.
564   // During cleanup we allow setReadCallback(nullptr) to be called even if the
565   // read callback is already unset and we have been detached from an event
566   // base.  This check prevents us from asserting
567   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
568   if (callback == readCallback_) {
569     return;
570   }
571
572   /* We are removing a read callback */
573   if (callback == nullptr &&
574       immediateReadHandler_.isLoopCallbackScheduled()) {
575     immediateReadHandler_.cancelLoopCallback();
576   }
577
578   if (shutdownFlags_ & SHUT_READ) {
579     // Reads have already been shut down on this socket.
580     //
581     // Allow setReadCallback(nullptr) to be called in this case, but don't
582     // allow a new callback to be set.
583     //
584     // For example, setReadCallback(nullptr) can happen after an error if we
585     // invoke some other error callback before invoking readError().  The other
586     // error callback that is invoked first may go ahead and clear the read
587     // callback before we get a chance to invoke readError().
588     if (callback != nullptr) {
589       return invalidState(callback);
590     }
591     assert((eventFlags_ & EventHandler::READ) == 0);
592     readCallback_ = nullptr;
593     return;
594   }
595
596   DestructorGuard dg(this);
597   assert(eventBase_->isInEventBaseThread());
598
599   switch ((StateEnum)state_) {
600     case StateEnum::CONNECTING:
601     case StateEnum::FAST_OPEN:
602       // For convenience, we allow the read callback to be set while we are
603       // still connecting.  We just store the callback for now.  Once the
604       // connection completes we'll register for read events.
605       readCallback_ = callback;
606       return;
607     case StateEnum::ESTABLISHED:
608     {
609       readCallback_ = callback;
610       uint16_t oldFlags = eventFlags_;
611       if (readCallback_) {
612         eventFlags_ |= EventHandler::READ;
613       } else {
614         eventFlags_ &= ~EventHandler::READ;
615       }
616
617       // Update our registration if our flags have changed
618       if (eventFlags_ != oldFlags) {
619         // We intentionally ignore the return value here.
620         // updateEventRegistration() will move us into the error state if it
621         // fails, and we don't need to do anything else here afterwards.
622         (void)updateEventRegistration();
623       }
624
625       if (readCallback_) {
626         checkForImmediateRead();
627       }
628       return;
629     }
630     case StateEnum::CLOSED:
631     case StateEnum::ERROR:
632       // We should never reach here.  SHUT_READ should always be set
633       // if we are in STATE_CLOSED or STATE_ERROR.
634       assert(false);
635       return invalidState(callback);
636     case StateEnum::UNINIT:
637       // We do not allow setReadCallback() to be called before we start
638       // connecting.
639       return invalidState(callback);
640   }
641
642   // We don't put a default case in the switch statement, so that the compiler
643   // will warn us to update the switch statement if a new state is added.
644   return invalidState(callback);
645 }
646
647 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
648   return readCallback_;
649 }
650
651 void AsyncSocket::write(WriteCallback* callback,
652                          const void* buf, size_t bytes, WriteFlags flags) {
653   iovec op;
654   op.iov_base = const_cast<void*>(buf);
655   op.iov_len = bytes;
656   writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
657 }
658
659 void AsyncSocket::writev(WriteCallback* callback,
660                           const iovec* vec,
661                           size_t count,
662                           WriteFlags flags) {
663   writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
664 }
665
666 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
667                               WriteFlags flags) {
668   constexpr size_t kSmallSizeMax = 64;
669   size_t count = buf->countChainElements();
670   if (count <= kSmallSizeMax) {
671     // suppress "warning: variable length array 'vec' is used [-Wvla]"
672     FOLLY_PUSH_WARNING;
673     FOLLY_GCC_DISABLE_WARNING(vla);
674     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
675     FOLLY_POP_WARNING;
676
677     writeChainImpl(callback, vec, count, std::move(buf), flags);
678   } else {
679     iovec* vec = new iovec[count];
680     writeChainImpl(callback, vec, count, std::move(buf), flags);
681     delete[] vec;
682   }
683 }
684
685 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
686     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
687   size_t veclen = buf->fillIov(vec, count);
688   writeImpl(callback, vec, veclen, std::move(buf), flags);
689 }
690
691 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
692                              size_t count, unique_ptr<IOBuf>&& buf,
693                              WriteFlags flags) {
694   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
695           << ", callback=" << callback << ", count=" << count
696           << ", state=" << state_;
697   DestructorGuard dg(this);
698   unique_ptr<IOBuf>ioBuf(std::move(buf));
699   assert(eventBase_->isInEventBaseThread());
700
701   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
702     // No new writes may be performed after the write side of the socket has
703     // been shutdown.
704     //
705     // We could just call callback->writeError() here to fail just this write.
706     // However, fail hard and use invalidState() to fail all outstanding
707     // callbacks and move the socket into the error state.  There's most likely
708     // a bug in the caller's code, so we abort everything rather than trying to
709     // proceed as best we can.
710     return invalidState(callback);
711   }
712
713   uint32_t countWritten = 0;
714   uint32_t partialWritten = 0;
715   int bytesWritten = 0;
716   bool mustRegister = false;
717   if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
718       !connecting()) {
719     if (writeReqHead_ == nullptr) {
720       // If we are established and there are no other writes pending,
721       // we can attempt to perform the write immediately.
722       assert(writeReqTail_ == nullptr);
723       assert((eventFlags_ & EventHandler::WRITE) == 0);
724
725       auto writeResult =
726           performWrite(vec, count, flags, &countWritten, &partialWritten);
727       bytesWritten = writeResult.writeReturn;
728       if (bytesWritten < 0) {
729         auto errnoCopy = errno;
730         if (writeResult.exception) {
731           return failWrite(__func__, callback, 0, *writeResult.exception);
732         }
733         AsyncSocketException ex(
734             AsyncSocketException::INTERNAL_ERROR,
735             withAddr("writev failed"),
736             errnoCopy);
737         return failWrite(__func__, callback, 0, ex);
738       } else if (countWritten == count) {
739         // We successfully wrote everything.
740         // Invoke the callback and return.
741         if (callback) {
742           callback->writeSuccess();
743         }
744         return;
745       } else { // continue writing the next writeReq
746         if (bufferCallback_) {
747           bufferCallback_->onEgressBuffered();
748         }
749       }
750       if (!connecting()) {
751         // Writes might put the socket back into connecting state
752         // if TFO is enabled, and using TFO fails.
753         // This means that write timeouts would not be active, however
754         // connect timeouts would affect this stage.
755         mustRegister = true;
756       }
757     }
758   } else if (!connecting()) {
759     // Invalid state for writing
760     return invalidState(callback);
761   }
762
763   // Create a new WriteRequest to add to the queue
764   WriteRequest* req;
765   try {
766     req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
767                                         count - countWritten, partialWritten,
768                                         bytesWritten, std::move(ioBuf), flags);
769   } catch (const std::exception& ex) {
770     // we mainly expect to catch std::bad_alloc here
771     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
772         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
773     return failWrite(__func__, callback, bytesWritten, tex);
774   }
775   req->consume();
776   if (writeReqTail_ == nullptr) {
777     assert(writeReqHead_ == nullptr);
778     writeReqHead_ = writeReqTail_ = req;
779   } else {
780     writeReqTail_->append(req);
781     writeReqTail_ = req;
782   }
783
784   // Register for write events if are established and not currently
785   // waiting on write events
786   if (mustRegister) {
787     assert(state_ == StateEnum::ESTABLISHED);
788     assert((eventFlags_ & EventHandler::WRITE) == 0);
789     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
790       assert(state_ == StateEnum::ERROR);
791       return;
792     }
793     if (sendTimeout_ > 0) {
794       // Schedule a timeout to fire if the write takes too long.
795       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
796         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
797                                withAddr("failed to schedule send timeout"));
798         return failWrite(__func__, ex);
799       }
800     }
801   }
802 }
803
804 void AsyncSocket::writeRequest(WriteRequest* req) {
805   if (writeReqTail_ == nullptr) {
806     assert(writeReqHead_ == nullptr);
807     writeReqHead_ = writeReqTail_ = req;
808     req->start();
809   } else {
810     writeReqTail_->append(req);
811     writeReqTail_ = req;
812   }
813 }
814
815 void AsyncSocket::close() {
816   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
817           << ", state=" << state_ << ", shutdownFlags="
818           << std::hex << (int) shutdownFlags_;
819
820   // close() is only different from closeNow() when there are pending writes
821   // that need to drain before we can close.  In all other cases, just call
822   // closeNow().
823   //
824   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
825   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
826   // is still running.  (e.g., If there are multiple pending writes, and we
827   // call writeError() on the first one, it may call close().  In this case we
828   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
829   // writes will still be in the queue.)
830   //
831   // We only need to drain pending writes if we are still in STATE_CONNECTING
832   // or STATE_ESTABLISHED
833   if ((writeReqHead_ == nullptr) ||
834       !(state_ == StateEnum::CONNECTING ||
835       state_ == StateEnum::ESTABLISHED)) {
836     closeNow();
837     return;
838   }
839
840   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
841   // destroyed until close() returns.
842   DestructorGuard dg(this);
843   assert(eventBase_->isInEventBaseThread());
844
845   // Since there are write requests pending, we have to set the
846   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
847   // connect finishes and we finish writing these requests.
848   //
849   // Set SHUT_READ to indicate that reads are shut down, and set the
850   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
851   // pending writes complete.
852   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
853
854   // If a read callback is set, invoke readEOF() immediately to inform it that
855   // the socket has been closed and no more data can be read.
856   if (readCallback_) {
857     // Disable reads if they are enabled
858     if (!updateEventRegistration(0, EventHandler::READ)) {
859       // We're now in the error state; callbacks have been cleaned up
860       assert(state_ == StateEnum::ERROR);
861       assert(readCallback_ == nullptr);
862     } else {
863       ReadCallback* callback = readCallback_;
864       readCallback_ = nullptr;
865       callback->readEOF();
866     }
867   }
868 }
869
870 void AsyncSocket::closeNow() {
871   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
872           << ", state=" << state_ << ", shutdownFlags="
873           << std::hex << (int) shutdownFlags_;
874   DestructorGuard dg(this);
875   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
876
877   switch (state_) {
878     case StateEnum::ESTABLISHED:
879     case StateEnum::CONNECTING:
880     case StateEnum::FAST_OPEN: {
881       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
882       state_ = StateEnum::CLOSED;
883
884       // If the write timeout was set, cancel it.
885       writeTimeout_.cancelTimeout();
886
887       // If we are registered for I/O events, unregister.
888       if (eventFlags_ != EventHandler::NONE) {
889         eventFlags_ = EventHandler::NONE;
890         if (!updateEventRegistration()) {
891           // We will have been moved into the error state.
892           assert(state_ == StateEnum::ERROR);
893           return;
894         }
895       }
896
897       if (immediateReadHandler_.isLoopCallbackScheduled()) {
898         immediateReadHandler_.cancelLoopCallback();
899       }
900
901       if (fd_ >= 0) {
902         ioHandler_.changeHandlerFD(-1);
903         doClose();
904       }
905
906       invokeConnectErr(socketClosedLocallyEx);
907
908       failAllWrites(socketClosedLocallyEx);
909
910       if (readCallback_) {
911         ReadCallback* callback = readCallback_;
912         readCallback_ = nullptr;
913         callback->readEOF();
914       }
915       return;
916     }
917     case StateEnum::CLOSED:
918       // Do nothing.  It's possible that we are being called recursively
919       // from inside a callback that we invoked inside another call to close()
920       // that is still running.
921       return;
922     case StateEnum::ERROR:
923       // Do nothing.  The error handling code has performed (or is performing)
924       // cleanup.
925       return;
926     case StateEnum::UNINIT:
927       assert(eventFlags_ == EventHandler::NONE);
928       assert(connectCallback_ == nullptr);
929       assert(readCallback_ == nullptr);
930       assert(writeReqHead_ == nullptr);
931       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
932       state_ = StateEnum::CLOSED;
933       return;
934   }
935
936   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
937               << ") called in unknown state " << state_;
938 }
939
940 void AsyncSocket::closeWithReset() {
941   // Enable SO_LINGER, with the linger timeout set to 0.
942   // This will trigger a TCP reset when we close the socket.
943   if (fd_ >= 0) {
944     struct linger optLinger = {1, 0};
945     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
946       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
947               << "on " << fd_ << ": errno=" << errno;
948     }
949   }
950
951   // Then let closeNow() take care of the rest
952   closeNow();
953 }
954
955 void AsyncSocket::shutdownWrite() {
956   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
957           << ", state=" << state_ << ", shutdownFlags="
958           << std::hex << (int) shutdownFlags_;
959
960   // If there are no pending writes, shutdownWrite() is identical to
961   // shutdownWriteNow().
962   if (writeReqHead_ == nullptr) {
963     shutdownWriteNow();
964     return;
965   }
966
967   assert(eventBase_->isInEventBaseThread());
968
969   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
970   // shutdown will be performed once all writes complete.
971   shutdownFlags_ |= SHUT_WRITE_PENDING;
972 }
973
974 void AsyncSocket::shutdownWriteNow() {
975   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
976           << ", fd=" << fd_ << ", state=" << state_
977           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
978
979   if (shutdownFlags_ & SHUT_WRITE) {
980     // Writes are already shutdown; nothing else to do.
981     return;
982   }
983
984   // If SHUT_READ is already set, just call closeNow() to completely
985   // close the socket.  This can happen if close() was called with writes
986   // pending, and then shutdownWriteNow() is called before all pending writes
987   // complete.
988   if (shutdownFlags_ & SHUT_READ) {
989     closeNow();
990     return;
991   }
992
993   DestructorGuard dg(this);
994   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
995
996   switch (static_cast<StateEnum>(state_)) {
997     case StateEnum::ESTABLISHED:
998     {
999       shutdownFlags_ |= SHUT_WRITE;
1000
1001       // If the write timeout was set, cancel it.
1002       writeTimeout_.cancelTimeout();
1003
1004       // If we are registered for write events, unregister.
1005       if (!updateEventRegistration(0, EventHandler::WRITE)) {
1006         // We will have been moved into the error state.
1007         assert(state_ == StateEnum::ERROR);
1008         return;
1009       }
1010
1011       // Shutdown writes on the file descriptor
1012       shutdown(fd_, SHUT_WR);
1013
1014       // Immediately fail all write requests
1015       failAllWrites(socketShutdownForWritesEx);
1016       return;
1017     }
1018     case StateEnum::CONNECTING:
1019     {
1020       // Set the SHUT_WRITE_PENDING flag.
1021       // When the connection completes, it will check this flag,
1022       // shutdown the write half of the socket, and then set SHUT_WRITE.
1023       shutdownFlags_ |= SHUT_WRITE_PENDING;
1024
1025       // Immediately fail all write requests
1026       failAllWrites(socketShutdownForWritesEx);
1027       return;
1028     }
1029     case StateEnum::UNINIT:
1030       // Callers normally shouldn't call shutdownWriteNow() before the socket
1031       // even starts connecting.  Nonetheless, go ahead and set
1032       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
1033       // immediately shut down the write side of the socket.
1034       shutdownFlags_ |= SHUT_WRITE_PENDING;
1035       return;
1036     case StateEnum::FAST_OPEN:
1037       // In fast open state we haven't call connected yet, and if we shutdown
1038       // the writes, we will never try to call connect, so shut everything down
1039       shutdownFlags_ |= SHUT_WRITE;
1040       // Immediately fail all write requests
1041       failAllWrites(socketShutdownForWritesEx);
1042       return;
1043     case StateEnum::CLOSED:
1044     case StateEnum::ERROR:
1045       // We should never get here.  SHUT_WRITE should always be set
1046       // in STATE_CLOSED and STATE_ERROR.
1047       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
1048                  << ", fd=" << fd_ << ") in unexpected state " << state_
1049                  << " with SHUT_WRITE not set ("
1050                  << std::hex << (int) shutdownFlags_ << ")";
1051       assert(false);
1052       return;
1053   }
1054
1055   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
1056               << fd_ << ") called in unknown state " << state_;
1057 }
1058
1059 bool AsyncSocket::readable() const {
1060   if (fd_ == -1) {
1061     return false;
1062   }
1063   struct pollfd fds[1];
1064   fds[0].fd = fd_;
1065   fds[0].events = POLLIN;
1066   fds[0].revents = 0;
1067   int rc = poll(fds, 1, 0);
1068   return rc == 1;
1069 }
1070
1071 bool AsyncSocket::isPending() const {
1072   return ioHandler_.isPending();
1073 }
1074
1075 bool AsyncSocket::hangup() const {
1076   if (fd_ == -1) {
1077     // sanity check, no one should ask for hangup if we are not connected.
1078     assert(false);
1079     return false;
1080   }
1081 #ifdef POLLRDHUP // Linux-only
1082   struct pollfd fds[1];
1083   fds[0].fd = fd_;
1084   fds[0].events = POLLRDHUP|POLLHUP;
1085   fds[0].revents = 0;
1086   poll(fds, 1, 0);
1087   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1088 #else
1089   return false;
1090 #endif
1091 }
1092
1093 bool AsyncSocket::good() const {
1094   return (
1095       (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1096        state_ == StateEnum::ESTABLISHED) &&
1097       (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1098 }
1099
1100 bool AsyncSocket::error() const {
1101   return (state_ == StateEnum::ERROR);
1102 }
1103
1104 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1105   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1106           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1107           << ", state=" << state_ << ", events="
1108           << std::hex << eventFlags_ << ")";
1109   assert(eventBase_ == nullptr);
1110   assert(eventBase->isInEventBaseThread());
1111
1112   eventBase_ = eventBase;
1113   ioHandler_.attachEventBase(eventBase);
1114   writeTimeout_.attachEventBase(eventBase);
1115 }
1116
1117 void AsyncSocket::detachEventBase() {
1118   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1119           << ", old evb=" << eventBase_ << ", state=" << state_
1120           << ", events=" << std::hex << eventFlags_ << ")";
1121   assert(eventBase_ != nullptr);
1122   assert(eventBase_->isInEventBaseThread());
1123
1124   eventBase_ = nullptr;
1125   ioHandler_.detachEventBase();
1126   writeTimeout_.detachEventBase();
1127 }
1128
1129 bool AsyncSocket::isDetachable() const {
1130   DCHECK(eventBase_ != nullptr);
1131   DCHECK(eventBase_->isInEventBaseThread());
1132
1133   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1134 }
1135
1136 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1137   if (!localAddr_.isInitialized()) {
1138     localAddr_.setFromLocalAddress(fd_);
1139   }
1140   *address = localAddr_;
1141 }
1142
1143 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1144   if (!addr_.isInitialized()) {
1145     addr_.setFromPeerAddress(fd_);
1146   }
1147   *address = addr_;
1148 }
1149
1150 int AsyncSocket::setNoDelay(bool noDelay) {
1151   if (fd_ < 0) {
1152     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1153                << this << "(state=" << state_ << ")";
1154     return EINVAL;
1155
1156   }
1157
1158   int value = noDelay ? 1 : 0;
1159   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1160     int errnoCopy = errno;
1161     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1162             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1163             << strerror(errnoCopy);
1164     return errnoCopy;
1165   }
1166
1167   return 0;
1168 }
1169
1170 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1171
1172   #ifndef TCP_CONGESTION
1173   #define TCP_CONGESTION  13
1174   #endif
1175
1176   if (fd_ < 0) {
1177     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1178                << "socket " << this << "(state=" << state_ << ")";
1179     return EINVAL;
1180
1181   }
1182
1183   if (setsockopt(fd_, IPPROTO_TCP, TCP_CONGESTION, cname.c_str(),
1184         cname.length() + 1) != 0) {
1185     int errnoCopy = errno;
1186     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1187             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1188             << strerror(errnoCopy);
1189     return errnoCopy;
1190   }
1191
1192   return 0;
1193 }
1194
1195 int AsyncSocket::setQuickAck(bool quickack) {
1196   if (fd_ < 0) {
1197     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1198                << this << "(state=" << state_ << ")";
1199     return EINVAL;
1200
1201   }
1202
1203 #ifdef TCP_QUICKACK // Linux-only
1204   int value = quickack ? 1 : 0;
1205   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1206     int errnoCopy = errno;
1207     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1208             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1209             << strerror(errnoCopy);
1210     return errnoCopy;
1211   }
1212
1213   return 0;
1214 #else
1215   return ENOSYS;
1216 #endif
1217 }
1218
1219 int AsyncSocket::setSendBufSize(size_t bufsize) {
1220   if (fd_ < 0) {
1221     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1222                << this << "(state=" << state_ << ")";
1223     return EINVAL;
1224   }
1225
1226   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1227     int errnoCopy = errno;
1228     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1229             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1230             << strerror(errnoCopy);
1231     return errnoCopy;
1232   }
1233
1234   return 0;
1235 }
1236
1237 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1238   if (fd_ < 0) {
1239     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1240                << this << "(state=" << state_ << ")";
1241     return EINVAL;
1242   }
1243
1244   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1245     int errnoCopy = errno;
1246     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1247             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1248             << strerror(errnoCopy);
1249     return errnoCopy;
1250   }
1251
1252   return 0;
1253 }
1254
1255 int AsyncSocket::setTCPProfile(int profd) {
1256   if (fd_ < 0) {
1257     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1258                << this << "(state=" << state_ << ")";
1259     return EINVAL;
1260   }
1261
1262   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1263     int errnoCopy = errno;
1264     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1265             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1266             << strerror(errnoCopy);
1267     return errnoCopy;
1268   }
1269
1270   return 0;
1271 }
1272
1273 void AsyncSocket::ioReady(uint16_t events) noexcept {
1274   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
1275           << ", events=" << std::hex << events << ", state=" << state_;
1276   DestructorGuard dg(this);
1277   assert(events & EventHandler::READ_WRITE);
1278   assert(eventBase_->isInEventBaseThread());
1279
1280   uint16_t relevantEvents = events & EventHandler::READ_WRITE;
1281   if (relevantEvents == EventHandler::READ) {
1282     handleRead();
1283   } else if (relevantEvents == EventHandler::WRITE) {
1284     handleWrite();
1285   } else if (relevantEvents == EventHandler::READ_WRITE) {
1286     EventBase* originalEventBase = eventBase_;
1287     // If both read and write events are ready, process writes first.
1288     handleWrite();
1289
1290     // Return now if handleWrite() detached us from our EventBase
1291     if (eventBase_ != originalEventBase) {
1292       return;
1293     }
1294
1295     // Only call handleRead() if a read callback is still installed.
1296     // (It's possible that the read callback was uninstalled during
1297     // handleWrite().)
1298     if (readCallback_) {
1299       handleRead();
1300     }
1301   } else {
1302     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1303                << std::hex << events << "(this=" << this << ")";
1304     abort();
1305   }
1306 }
1307
1308 AsyncSocket::ReadResult
1309 AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
1310   VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
1311           << ", buflen=" << *buflen;
1312
1313   int recvFlags = 0;
1314   if (peek_) {
1315     recvFlags |= MSG_PEEK;
1316   }
1317
1318   ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
1319   if (bytes < 0) {
1320     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1321       // No more data to read right now.
1322       return ReadResult(READ_BLOCKING);
1323     } else {
1324       return ReadResult(READ_ERROR);
1325     }
1326   } else {
1327     appBytesReceived_ += bytes;
1328     return ReadResult(bytes);
1329   }
1330 }
1331
1332 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1333   // no matter what, buffer should be preapared for non-ssl socket
1334   CHECK(readCallback_);
1335   readCallback_->getReadBuffer(buf, buflen);
1336 }
1337
1338 void AsyncSocket::handleRead() noexcept {
1339   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1340           << ", state=" << state_;
1341   assert(state_ == StateEnum::ESTABLISHED);
1342   assert((shutdownFlags_ & SHUT_READ) == 0);
1343   assert(readCallback_ != nullptr);
1344   assert(eventFlags_ & EventHandler::READ);
1345
1346   // Loop until:
1347   // - a read attempt would block
1348   // - readCallback_ is uninstalled
1349   // - the number of loop iterations exceeds the optional maximum
1350   // - this AsyncSocket is moved to another EventBase
1351   //
1352   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1353   // which is why need to check for it here.
1354   //
1355   // The last bullet point is slightly subtle.  readDataAvailable() may also
1356   // detach this socket from this EventBase.  However, before
1357   // readDataAvailable() returns another thread may pick it up, attach it to
1358   // a different EventBase, and install another readCallback_.  We need to
1359   // exit immediately after readDataAvailable() returns if the eventBase_ has
1360   // changed.  (The caller must perform some sort of locking to transfer the
1361   // AsyncSocket between threads properly.  This will be sufficient to ensure
1362   // that this thread sees the updated eventBase_ variable after
1363   // readDataAvailable() returns.)
1364   uint16_t numReads = 0;
1365   EventBase* originalEventBase = eventBase_;
1366   while (readCallback_ && eventBase_ == originalEventBase) {
1367     // Get the buffer to read into.
1368     void* buf = nullptr;
1369     size_t buflen = 0, offset = 0;
1370     try {
1371       prepareReadBuffer(&buf, &buflen);
1372       VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
1373     } catch (const AsyncSocketException& ex) {
1374       return failRead(__func__, ex);
1375     } catch (const std::exception& ex) {
1376       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1377                               string("ReadCallback::getReadBuffer() "
1378                                      "threw exception: ") +
1379                               ex.what());
1380       return failRead(__func__, tex);
1381     } catch (...) {
1382       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1383                              "ReadCallback::getReadBuffer() threw "
1384                              "non-exception type");
1385       return failRead(__func__, ex);
1386     }
1387     if (!isBufferMovable_ && (buf == nullptr || buflen == 0)) {
1388       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1389                              "ReadCallback::getReadBuffer() returned "
1390                              "empty buffer");
1391       return failRead(__func__, ex);
1392     }
1393
1394     // Perform the read
1395     auto readResult = performRead(&buf, &buflen, &offset);
1396     auto bytesRead = readResult.readReturn;
1397     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
1398             << bytesRead << " bytes";
1399     if (bytesRead > 0) {
1400       if (!isBufferMovable_) {
1401         readCallback_->readDataAvailable(bytesRead);
1402       } else {
1403         CHECK(kOpenSslModeMoveBufferOwnership);
1404         VLOG(5) << "this=" << this << ", AsyncSocket::handleRead() got "
1405                 << "buf=" << buf << ", " << bytesRead << "/" << buflen
1406                 << ", offset=" << offset;
1407         auto readBuf = folly::IOBuf::takeOwnership(buf, buflen);
1408         readBuf->trimStart(offset);
1409         readBuf->trimEnd(buflen - offset - bytesRead);
1410         readCallback_->readBufferAvailable(std::move(readBuf));
1411       }
1412
1413       // Fall through and continue around the loop if the read
1414       // completely filled the available buffer.
1415       // Note that readCallback_ may have been uninstalled or changed inside
1416       // readDataAvailable().
1417       if (size_t(bytesRead) < buflen) {
1418         return;
1419       }
1420     } else if (bytesRead == READ_BLOCKING) {
1421         // No more data to read right now.
1422         return;
1423     } else if (bytesRead == READ_ERROR) {
1424       readErr_ = READ_ERROR;
1425       if (readResult.exception) {
1426         return failRead(__func__, *readResult.exception);
1427       }
1428       auto errnoCopy = errno;
1429       AsyncSocketException ex(
1430           AsyncSocketException::INTERNAL_ERROR,
1431           withAddr("recv() failed"),
1432           errnoCopy);
1433       return failRead(__func__, ex);
1434     } else {
1435       assert(bytesRead == READ_EOF);
1436       readErr_ = READ_EOF;
1437       // EOF
1438       shutdownFlags_ |= SHUT_READ;
1439       if (!updateEventRegistration(0, EventHandler::READ)) {
1440         // we've already been moved into STATE_ERROR
1441         assert(state_ == StateEnum::ERROR);
1442         assert(readCallback_ == nullptr);
1443         return;
1444       }
1445
1446       ReadCallback* callback = readCallback_;
1447       readCallback_ = nullptr;
1448       callback->readEOF();
1449       return;
1450     }
1451     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1452       if (readCallback_ != nullptr) {
1453         // We might still have data in the socket.
1454         // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
1455         scheduleImmediateRead();
1456       }
1457       return;
1458     }
1459   }
1460 }
1461
1462 /**
1463  * This function attempts to write as much data as possible, until no more data
1464  * can be written.
1465  *
1466  * - If it sends all available data, it unregisters for write events, and stops
1467  *   the writeTimeout_.
1468  *
1469  * - If not all of the data can be sent immediately, it reschedules
1470  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1471  *   registered for write events.
1472  */
1473 void AsyncSocket::handleWrite() noexcept {
1474   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1475           << ", state=" << state_;
1476   DestructorGuard dg(this);
1477
1478   if (state_ == StateEnum::CONNECTING) {
1479     handleConnect();
1480     return;
1481   }
1482
1483   // Normal write
1484   assert(state_ == StateEnum::ESTABLISHED);
1485   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1486   assert(writeReqHead_ != nullptr);
1487
1488   // Loop until we run out of write requests,
1489   // or until this socket is moved to another EventBase.
1490   // (See the comment in handleRead() explaining how this can happen.)
1491   EventBase* originalEventBase = eventBase_;
1492   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1493     auto writeResult = writeReqHead_->performWrite();
1494     if (writeResult.writeReturn < 0) {
1495       if (writeResult.exception) {
1496         return failWrite(__func__, *writeResult.exception);
1497       }
1498       auto errnoCopy = errno;
1499       AsyncSocketException ex(
1500           AsyncSocketException::INTERNAL_ERROR,
1501           withAddr("writev() failed"),
1502           errnoCopy);
1503       return failWrite(__func__, ex);
1504     } else if (writeReqHead_->isComplete()) {
1505       // We finished this request
1506       WriteRequest* req = writeReqHead_;
1507       writeReqHead_ = req->getNext();
1508
1509       if (writeReqHead_ == nullptr) {
1510         writeReqTail_ = nullptr;
1511         // This is the last write request.
1512         // Unregister for write events and cancel the send timer
1513         // before we invoke the callback.  We have to update the state properly
1514         // before calling the callback, since it may want to detach us from
1515         // the EventBase.
1516         if (eventFlags_ & EventHandler::WRITE) {
1517           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1518             assert(state_ == StateEnum::ERROR);
1519             return;
1520           }
1521           // Stop the send timeout
1522           writeTimeout_.cancelTimeout();
1523         }
1524         assert(!writeTimeout_.isScheduled());
1525
1526         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1527         // we finish sending the last write request.
1528         //
1529         // We have to do this before invoking writeSuccess(), since
1530         // writeSuccess() may detach us from our EventBase.
1531         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1532           assert(connectCallback_ == nullptr);
1533           shutdownFlags_ |= SHUT_WRITE;
1534
1535           if (shutdownFlags_ & SHUT_READ) {
1536             // Reads have already been shutdown.  Fully close the socket and
1537             // move to STATE_CLOSED.
1538             //
1539             // Note: This code currently moves us to STATE_CLOSED even if
1540             // close() hasn't ever been called.  This can occur if we have
1541             // received EOF from the peer and shutdownWrite() has been called
1542             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1543             // case, until close() is actually called?  I can't think of a
1544             // reason why we would need to do so.  No other operations besides
1545             // calling close() or destroying the socket can be performed at
1546             // this point.
1547             assert(readCallback_ == nullptr);
1548             state_ = StateEnum::CLOSED;
1549             if (fd_ >= 0) {
1550               ioHandler_.changeHandlerFD(-1);
1551               doClose();
1552             }
1553           } else {
1554             // Reads are still enabled, so we are only doing a half-shutdown
1555             shutdown(fd_, SHUT_WR);
1556           }
1557         }
1558       }
1559
1560       // Invoke the callback
1561       WriteCallback* callback = req->getCallback();
1562       req->destroy();
1563       if (callback) {
1564         callback->writeSuccess();
1565       }
1566       // We'll continue around the loop, trying to write another request
1567     } else {
1568       // Partial write.
1569       if (bufferCallback_) {
1570         bufferCallback_->onEgressBuffered();
1571       }
1572       writeReqHead_->consume();
1573       // Stop after a partial write; it's highly likely that a subsequent write
1574       // attempt will just return EAGAIN.
1575       //
1576       // Ensure that we are registered for write events.
1577       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1578         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1579           assert(state_ == StateEnum::ERROR);
1580           return;
1581         }
1582       }
1583
1584       // Reschedule the send timeout, since we have made some write progress.
1585       if (sendTimeout_ > 0) {
1586         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1587           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1588               withAddr("failed to reschedule write timeout"));
1589           return failWrite(__func__, ex);
1590         }
1591       }
1592       return;
1593     }
1594   }
1595   if (!writeReqHead_ && bufferCallback_) {
1596     bufferCallback_->onEgressBufferCleared();
1597   }
1598 }
1599
1600 void AsyncSocket::checkForImmediateRead() noexcept {
1601   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1602   // (However, note that some subclasses do override this method.)
1603   //
1604   // Simply calling handleRead() here would be bad, as this would call
1605   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1606   // buffer even though no data may be available.  This would waste lots of
1607   // memory, since the buffer will sit around unused until the socket actually
1608   // becomes readable.
1609   //
1610   // Checking if the socket is readable now also seems like it would probably
1611   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1612   // would just waste an extra system call.  Even if it is readable, waiting to
1613   // find out from libevent on the next event loop doesn't seem that bad.
1614 }
1615
1616 void AsyncSocket::handleInitialReadWrite() noexcept {
1617   // Our callers should already be holding a DestructorGuard, but grab
1618   // one here just to make sure, in case one of our calling code paths ever
1619   // changes.
1620   DestructorGuard dg(this);
1621   // If we have a readCallback_, make sure we enable read events.  We
1622   // may already be registered for reads if connectSuccess() set
1623   // the read calback.
1624   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1625     assert(state_ == StateEnum::ESTABLISHED);
1626     assert((shutdownFlags_ & SHUT_READ) == 0);
1627     if (!updateEventRegistration(EventHandler::READ, 0)) {
1628       assert(state_ == StateEnum::ERROR);
1629       return;
1630     }
1631     checkForImmediateRead();
1632   } else if (readCallback_ == nullptr) {
1633     // Unregister for read events.
1634     updateEventRegistration(0, EventHandler::READ);
1635   }
1636
1637   // If we have write requests pending, try to send them immediately.
1638   // Since we just finished accepting, there is a very good chance that we can
1639   // write without blocking.
1640   //
1641   // However, we only process them if EventHandler::WRITE is not already set,
1642   // which means that we're already blocked on a write attempt.  (This can
1643   // happen if connectSuccess() called write() before returning.)
1644   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1645     // Call handleWrite() to perform write processing.
1646     handleWrite();
1647   } else if (writeReqHead_ == nullptr) {
1648     // Unregister for write event.
1649     updateEventRegistration(0, EventHandler::WRITE);
1650   }
1651 }
1652
1653 void AsyncSocket::handleConnect() noexcept {
1654   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1655           << ", state=" << state_;
1656   assert(state_ == StateEnum::CONNECTING);
1657   // SHUT_WRITE can never be set while we are still connecting;
1658   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1659   // finishes
1660   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1661
1662   // In case we had a connect timeout, cancel the timeout
1663   writeTimeout_.cancelTimeout();
1664   // We don't use a persistent registration when waiting on a connect event,
1665   // so we have been automatically unregistered now.  Update eventFlags_ to
1666   // reflect reality.
1667   assert(eventFlags_ == EventHandler::WRITE);
1668   eventFlags_ = EventHandler::NONE;
1669
1670   // Call getsockopt() to check if the connect succeeded
1671   int error;
1672   socklen_t len = sizeof(error);
1673   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1674   if (rv != 0) {
1675     auto errnoCopy = errno;
1676     AsyncSocketException ex(
1677         AsyncSocketException::INTERNAL_ERROR,
1678         withAddr("error calling getsockopt() after connect"),
1679         errnoCopy);
1680     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1681                << fd_ << " host=" << addr_.describe()
1682                << ") exception:" << ex.what();
1683     return failConnect(__func__, ex);
1684   }
1685
1686   if (error != 0) {
1687     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1688                            "connect failed", error);
1689     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1690             << fd_ << " host=" << addr_.describe()
1691             << ") exception: " << ex.what();
1692     return failConnect(__func__, ex);
1693   }
1694
1695   // Move into STATE_ESTABLISHED
1696   state_ = StateEnum::ESTABLISHED;
1697
1698   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1699   // perform, immediately shutdown the write half of the socket.
1700   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1701     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1702     // are still connecting we just abort the connect rather than waiting for
1703     // it to complete.
1704     assert((shutdownFlags_ & SHUT_READ) == 0);
1705     shutdown(fd_, SHUT_WR);
1706     shutdownFlags_ |= SHUT_WRITE;
1707   }
1708
1709   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1710           << "successfully connected; state=" << state_;
1711
1712   // Remember the EventBase we are attached to, before we start invoking any
1713   // callbacks (since the callbacks may call detachEventBase()).
1714   EventBase* originalEventBase = eventBase_;
1715
1716   invokeConnectSuccess();
1717   // Note that the connect callback may have changed our state.
1718   // (set or unset the read callback, called write(), closed the socket, etc.)
1719   // The following code needs to handle these situations correctly.
1720   //
1721   // If the socket has been closed, readCallback_ and writeReqHead_ will
1722   // always be nullptr, so that will prevent us from trying to read or write.
1723   //
1724   // The main thing to check for is if eventBase_ is still originalEventBase.
1725   // If not, we have been detached from this event base, so we shouldn't
1726   // perform any more operations.
1727   if (eventBase_ != originalEventBase) {
1728     return;
1729   }
1730
1731   handleInitialReadWrite();
1732 }
1733
1734 void AsyncSocket::timeoutExpired() noexcept {
1735   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1736           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1737   DestructorGuard dg(this);
1738   assert(eventBase_->isInEventBaseThread());
1739
1740   if (state_ == StateEnum::CONNECTING) {
1741     // connect() timed out
1742     // Unregister for I/O events.
1743     if (connectCallback_) {
1744       AsyncSocketException ex(
1745           AsyncSocketException::TIMED_OUT, "connect timed out");
1746       failConnect(__func__, ex);
1747     } else {
1748       // we faced a connect error without a connect callback, which could
1749       // happen due to TFO.
1750       AsyncSocketException ex(
1751           AsyncSocketException::TIMED_OUT, "write timed out during connection");
1752       failWrite(__func__, ex);
1753     }
1754   } else {
1755     // a normal write operation timed out
1756     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
1757     failWrite(__func__, ex);
1758   }
1759 }
1760
1761 ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
1762   return detail::tfo_sendmsg(fd, msg, msg_flags);
1763 }
1764
1765 AsyncSocket::WriteResult
1766 AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
1767   ssize_t totalWritten = 0;
1768   if (state_ == StateEnum::FAST_OPEN) {
1769     sockaddr_storage addr;
1770     auto len = addr_.getAddress(&addr);
1771     msg->msg_name = &addr;
1772     msg->msg_namelen = len;
1773     totalWritten = tfoSendMsg(fd_, msg, msg_flags);
1774     if (totalWritten >= 0) {
1775       tfoFinished_ = true;
1776       state_ = StateEnum::ESTABLISHED;
1777       // We schedule this asynchrously so that we don't end up
1778       // invoking initial read or write while a write is in progress.
1779       scheduleInitialReadWrite();
1780     } else if (errno == EINPROGRESS) {
1781       VLOG(4) << "TFO falling back to connecting";
1782       // A normal sendmsg doesn't return EINPROGRESS, however
1783       // TFO might fallback to connecting if there is no
1784       // cookie.
1785       state_ = StateEnum::CONNECTING;
1786       try {
1787         scheduleConnectTimeout();
1788         registerForConnectEvents();
1789       } catch (const AsyncSocketException& ex) {
1790         return WriteResult(
1791             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1792       }
1793       // Let's fake it that no bytes were written and return an errno.
1794       errno = EAGAIN;
1795       totalWritten = -1;
1796     } else if (errno == EOPNOTSUPP) {
1797       VLOG(4) << "TFO not supported";
1798       // Try falling back to connecting.
1799       state_ = StateEnum::CONNECTING;
1800       try {
1801         int ret = socketConnect((const sockaddr*)&addr, len);
1802         if (ret == 0) {
1803           // connect succeeded immediately
1804           // Treat this like no data was written.
1805           state_ = StateEnum::ESTABLISHED;
1806           scheduleInitialReadWrite();
1807         }
1808         // If there was no exception during connections,
1809         // we would return that no bytes were written.
1810         errno = EAGAIN;
1811         totalWritten = -1;
1812       } catch (const AsyncSocketException& ex) {
1813         return WriteResult(
1814             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1815       }
1816     } else if (errno == EAGAIN) {
1817       // Normally sendmsg would indicate that the write would block.
1818       // However in the fast open case, it would indicate that sendmsg
1819       // fell back to a connect. This is a return code from connect()
1820       // instead, and is an error condition indicating no fds available.
1821       return WriteResult(
1822           WRITE_ERROR,
1823           folly::make_unique<AsyncSocketException>(
1824               AsyncSocketException::UNKNOWN, "No more free local ports"));
1825     }
1826   } else {
1827     totalWritten = ::sendmsg(fd, msg, msg_flags);
1828   }
1829   return WriteResult(totalWritten);
1830 }
1831
1832 AsyncSocket::WriteResult AsyncSocket::performWrite(
1833     const iovec* vec,
1834     uint32_t count,
1835     WriteFlags flags,
1836     uint32_t* countWritten,
1837     uint32_t* partialWritten) {
1838   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
1839   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
1840   // (since it may terminate the program if the main program doesn't explicitly
1841   // ignore it).
1842   struct msghdr msg;
1843   msg.msg_name = nullptr;
1844   msg.msg_namelen = 0;
1845   msg.msg_iov = const_cast<iovec *>(vec);
1846   msg.msg_iovlen = std::min<size_t>(count, kIovMax);
1847   msg.msg_control = nullptr;
1848   msg.msg_controllen = 0;
1849   msg.msg_flags = 0;
1850
1851   int msg_flags = MSG_DONTWAIT;
1852
1853 #ifdef MSG_NOSIGNAL // Linux-only
1854   msg_flags |= MSG_NOSIGNAL;
1855   if (isSet(flags, WriteFlags::CORK)) {
1856     // MSG_MORE tells the kernel we have more data to send, so wait for us to
1857     // give it the rest of the data rather than immediately sending a partial
1858     // frame, even when TCP_NODELAY is enabled.
1859     msg_flags |= MSG_MORE;
1860   }
1861 #endif
1862   if (isSet(flags, WriteFlags::EOR)) {
1863     // marks that this is the last byte of a record (response)
1864     msg_flags |= MSG_EOR;
1865   }
1866   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
1867   auto totalWritten = writeResult.writeReturn;
1868   if (totalWritten < 0) {
1869     if (!writeResult.exception && errno == EAGAIN) {
1870       // TCP buffer is full; we can't write any more data right now.
1871       *countWritten = 0;
1872       *partialWritten = 0;
1873       return WriteResult(0);
1874     }
1875     // error
1876     *countWritten = 0;
1877     *partialWritten = 0;
1878     return writeResult;
1879   }
1880
1881   appBytesWritten_ += totalWritten;
1882
1883   uint32_t bytesWritten;
1884   uint32_t n;
1885   for (bytesWritten = totalWritten, n = 0; n < count; ++n) {
1886     const iovec* v = vec + n;
1887     if (v->iov_len > bytesWritten) {
1888       // Partial write finished in the middle of this iovec
1889       *countWritten = n;
1890       *partialWritten = bytesWritten;
1891       return WriteResult(totalWritten);
1892     }
1893
1894     bytesWritten -= v->iov_len;
1895   }
1896
1897   assert(bytesWritten == 0);
1898   *countWritten = n;
1899   *partialWritten = 0;
1900   return WriteResult(totalWritten);
1901 }
1902
1903 /**
1904  * Re-register the EventHandler after eventFlags_ has changed.
1905  *
1906  * If an error occurs, fail() is called to move the socket into the error state
1907  * and call all currently installed callbacks.  After an error, the
1908  * AsyncSocket is completely unregistered.
1909  *
1910  * @return Returns true on succcess, or false on error.
1911  */
1912 bool AsyncSocket::updateEventRegistration() {
1913   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
1914           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
1915           << ", events=" << std::hex << eventFlags_;
1916   assert(eventBase_->isInEventBaseThread());
1917   if (eventFlags_ == EventHandler::NONE) {
1918     ioHandler_.unregisterHandler();
1919     return true;
1920   }
1921
1922   // Always register for persistent events, so we don't have to re-register
1923   // after being called back.
1924   if (!ioHandler_.registerHandler(eventFlags_ | EventHandler::PERSIST)) {
1925     eventFlags_ = EventHandler::NONE; // we're not registered after error
1926     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1927         withAddr("failed to update AsyncSocket event registration"));
1928     fail("updateEventRegistration", ex);
1929     return false;
1930   }
1931
1932   return true;
1933 }
1934
1935 bool AsyncSocket::updateEventRegistration(uint16_t enable,
1936                                            uint16_t disable) {
1937   uint16_t oldFlags = eventFlags_;
1938   eventFlags_ |= enable;
1939   eventFlags_ &= ~disable;
1940   if (eventFlags_ == oldFlags) {
1941     return true;
1942   } else {
1943     return updateEventRegistration();
1944   }
1945 }
1946
1947 void AsyncSocket::startFail() {
1948   // startFail() should only be called once
1949   assert(state_ != StateEnum::ERROR);
1950   assert(getDestructorGuardCount() > 0);
1951   state_ = StateEnum::ERROR;
1952   // Ensure that SHUT_READ and SHUT_WRITE are set,
1953   // so all future attempts to read or write will be rejected
1954   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1955
1956   if (eventFlags_ != EventHandler::NONE) {
1957     eventFlags_ = EventHandler::NONE;
1958     ioHandler_.unregisterHandler();
1959   }
1960   writeTimeout_.cancelTimeout();
1961
1962   if (fd_ >= 0) {
1963     ioHandler_.changeHandlerFD(-1);
1964     doClose();
1965   }
1966 }
1967
1968 void AsyncSocket::finishFail() {
1969   assert(state_ == StateEnum::ERROR);
1970   assert(getDestructorGuardCount() > 0);
1971
1972   AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1973                          withAddr("socket closing after error"));
1974   invokeConnectErr(ex);
1975   failAllWrites(ex);
1976
1977   if (readCallback_) {
1978     ReadCallback* callback = readCallback_;
1979     readCallback_ = nullptr;
1980     callback->readErr(ex);
1981   }
1982 }
1983
1984 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
1985   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1986              << state_ << " host=" << addr_.describe()
1987              << "): failed in " << fn << "(): "
1988              << ex.what();
1989   startFail();
1990   finishFail();
1991 }
1992
1993 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
1994   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1995                << state_ << " host=" << addr_.describe()
1996                << "): failed while connecting in " << fn << "(): "
1997                << ex.what();
1998   startFail();
1999
2000   invokeConnectErr(ex);
2001   finishFail();
2002 }
2003
2004 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
2005   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2006                << state_ << " host=" << addr_.describe()
2007                << "): failed while reading in " << fn << "(): "
2008                << ex.what();
2009   startFail();
2010
2011   if (readCallback_ != nullptr) {
2012     ReadCallback* callback = readCallback_;
2013     readCallback_ = nullptr;
2014     callback->readErr(ex);
2015   }
2016
2017   finishFail();
2018 }
2019
2020 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
2021   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2022                << state_ << " host=" << addr_.describe()
2023                << "): failed while writing in " << fn << "(): "
2024                << ex.what();
2025   startFail();
2026
2027   // Only invoke the first write callback, since the error occurred while
2028   // writing this request.  Let any other pending write callbacks be invoked in
2029   // finishFail().
2030   if (writeReqHead_ != nullptr) {
2031     WriteRequest* req = writeReqHead_;
2032     writeReqHead_ = req->getNext();
2033     WriteCallback* callback = req->getCallback();
2034     uint32_t bytesWritten = req->getTotalBytesWritten();
2035     req->destroy();
2036     if (callback) {
2037       callback->writeErr(bytesWritten, ex);
2038     }
2039   }
2040
2041   finishFail();
2042 }
2043
2044 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
2045                              size_t bytesWritten,
2046                              const AsyncSocketException& ex) {
2047   // This version of failWrite() is used when the failure occurs before
2048   // we've added the callback to writeReqHead_.
2049   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2050              << state_ << " host=" << addr_.describe()
2051              <<"): failed while writing in " << fn << "(): "
2052              << ex.what();
2053   startFail();
2054
2055   if (callback != nullptr) {
2056     callback->writeErr(bytesWritten, ex);
2057   }
2058
2059   finishFail();
2060 }
2061
2062 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
2063   // Invoke writeError() on all write callbacks.
2064   // This is used when writes are forcibly shutdown with write requests
2065   // pending, or when an error occurs with writes pending.
2066   while (writeReqHead_ != nullptr) {
2067     WriteRequest* req = writeReqHead_;
2068     writeReqHead_ = req->getNext();
2069     WriteCallback* callback = req->getCallback();
2070     if (callback) {
2071       callback->writeErr(req->getTotalBytesWritten(), ex);
2072     }
2073     req->destroy();
2074   }
2075 }
2076
2077 void AsyncSocket::invalidState(ConnectCallback* callback) {
2078   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
2079              << "): connect() called in invalid state " << state_;
2080
2081   /*
2082    * The invalidState() methods don't use the normal failure mechanisms,
2083    * since we don't know what state we are in.  We don't want to call
2084    * startFail()/finishFail() recursively if we are already in the middle of
2085    * cleaning up.
2086    */
2087
2088   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
2089                          "connect() called with socket in invalid state");
2090   connectEndTime_ = std::chrono::steady_clock::now();
2091   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2092     if (callback) {
2093       callback->connectErr(ex);
2094     }
2095   } else {
2096     // We can't use failConnect() here since connectCallback_
2097     // may already be set to another callback.  Invoke this ConnectCallback
2098     // here; any other connectCallback_ will be invoked in finishFail()
2099     startFail();
2100     if (callback) {
2101       callback->connectErr(ex);
2102     }
2103     finishFail();
2104   }
2105 }
2106
2107 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
2108   connectEndTime_ = std::chrono::steady_clock::now();
2109   if (connectCallback_) {
2110     ConnectCallback* callback = connectCallback_;
2111     connectCallback_ = nullptr;
2112     callback->connectErr(ex);
2113   }
2114 }
2115
2116 void AsyncSocket::invokeConnectSuccess() {
2117   connectEndTime_ = std::chrono::steady_clock::now();
2118   if (connectCallback_) {
2119     ConnectCallback* callback = connectCallback_;
2120     connectCallback_ = nullptr;
2121     callback->connectSuccess();
2122   }
2123 }
2124
2125 void AsyncSocket::invalidState(ReadCallback* callback) {
2126   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2127              << "): setReadCallback(" << callback
2128              << ") called in invalid state " << state_;
2129
2130   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2131                          "setReadCallback() called with socket in "
2132                          "invalid state");
2133   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2134     if (callback) {
2135       callback->readErr(ex);
2136     }
2137   } else {
2138     startFail();
2139     if (callback) {
2140       callback->readErr(ex);
2141     }
2142     finishFail();
2143   }
2144 }
2145
2146 void AsyncSocket::invalidState(WriteCallback* callback) {
2147   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2148              << "): write() called in invalid state " << state_;
2149
2150   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2151                          withAddr("write() called with socket in invalid state"));
2152   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2153     if (callback) {
2154       callback->writeErr(0, ex);
2155     }
2156   } else {
2157     startFail();
2158     if (callback) {
2159       callback->writeErr(0, ex);
2160     }
2161     finishFail();
2162   }
2163 }
2164
2165 void AsyncSocket::doClose() {
2166   if (fd_ == -1) return;
2167   if (shutdownSocketSet_) {
2168     shutdownSocketSet_->close(fd_);
2169   } else {
2170     ::close(fd_);
2171   }
2172   fd_ = -1;
2173 }
2174
2175 std::ostream& operator << (std::ostream& os,
2176                            const AsyncSocket::StateEnum& state) {
2177   os << static_cast<int>(state);
2178   return os;
2179 }
2180
2181 std::string AsyncSocket::withAddr(const std::string& s) {
2182   // Don't use addr_ directly because it may not be initialized
2183   // e.g. if constructed from fd
2184   folly::SocketAddress peer, local;
2185   try {
2186     getPeerAddress(&peer);
2187     getLocalAddress(&local);
2188   } catch (const std::exception&) {
2189     // ignore
2190   } catch (...) {
2191     // ignore
2192   }
2193   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
2194 }
2195
2196 void AsyncSocket::setBufferCallback(BufferCallback* cb) {
2197   bufferCallback_ = cb;
2198 }
2199
2200 } // folly