Fix AsyncSocket::handleRead
authorSarang Masti <mssarang@fb.com>
Wed, 27 May 2015 03:19:36 +0000 (20:19 -0700)
committerNoam Lerner <noamler@fb.com>
Wed, 3 Jun 2015 16:47:44 +0000 (09:47 -0700)
Summary:
If openssl has buffered data read from the socket, we might not get a
read event on the socket. So, we must schedule a readCallback to ensure
before exiting from AsyncSocket::handleRead if we have exhausted the
maxReadsPerEvent_ limit.

Test Plan: -- modifying existing test to test this corner case.

Reviewed By: davejwatson@fb.com

Subscribers: net-systems@, trunkagent, folly-diffs@, yfeldblum, chalfant

FB internal diff: D2102601

Tasks: 7168699

Signature: t1:2102601:1432837605:82e72a2a1875c08c9c1e8e831796c9c90df927fa

folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSSLSocketTest.h

index f477d5da6174bb8fd0203d6e81cc5665fb60a173..939788d7d22e0bde00ce8d96ea0f5e3d58c54df4 100644 (file)
@@ -221,7 +221,8 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
 AsyncSocket::AsyncSocket()
   : eventBase_(nullptr)
   , writeTimeout_(this, nullptr)
-  , ioHandler_(this, nullptr) {
+  , ioHandler_(this, nullptr)
+  , immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket()";
   init();
 }
@@ -229,7 +230,8 @@ AsyncSocket::AsyncSocket()
 AsyncSocket::AsyncSocket(EventBase* evb)
   : eventBase_(evb)
   , writeTimeout_(this, evb)
-  , ioHandler_(this, evb) {
+  , ioHandler_(this, evb)
+  , immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
   init();
 }
@@ -252,7 +254,8 @@ AsyncSocket::AsyncSocket(EventBase* evb,
 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
   : eventBase_(evb)
   , writeTimeout_(this, evb)
-  , ioHandler_(this, evb, fd) {
+  , ioHandler_(this, evb, fd)
+  , immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
           << fd << ")";
   init();
@@ -852,6 +855,10 @@ void AsyncSocket::closeNow() {
         }
       }
 
+      if (immediateReadHandler_.isLoopCallbackScheduled()) {
+        immediateReadHandler_.cancelLoopCallback();
+      }
+
       if (fd_ >= 0) {
         ioHandler_.changeHandlerFD(-1);
         doClose();
@@ -1357,6 +1364,9 @@ void AsyncSocket::handleRead() noexcept {
       return;
     }
     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
+      // We might still have data in the socket.
+      // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
+      scheduleImmediateRead();
       return;
     }
   }
index 866c5d91284f56a409f7789c334a6ca40924b8c7..85d7275ddb86ef900dde4f799d1d64055a164487 100644 (file)
@@ -552,6 +552,26 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   void init();
 
+  class ImmediateReadCB : public folly::EventBase::LoopCallback {
+   public:
+    explicit ImmediateReadCB(AsyncSocket* socket) : socket_(socket) {}
+    void runLoopCallback() noexcept override {
+      socket_->checkForImmediateRead();
+    }
+   private:
+    AsyncSocket* socket_;
+  };
+
+  /**
+   * Schedule checkForImmediateRead to be executed in the next loop
+   * iteration.
+   */
+  void scheduleImmediateRead() noexcept {
+    if (good()) {
+      eventBase_->runInLoop(&immediateReadHandler_);
+    }
+  }
+
   // event notification methods
   void ioReady(uint16_t events) noexcept;
   virtual void checkForImmediateRead() noexcept;
@@ -673,6 +693,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   EventBase* eventBase_;               ///< The EventBase
   WriteTimeout writeTimeout_;           ///< A timeout for connect and write
   IoHandler ioHandler_;                 ///< A EventHandler to monitor the fd
+  ImmediateReadCB immediateReadHandler_; ///< LoopCallback for checking read
 
   ConnectCallback* connectCallback_;    ///< ConnectCallback
   ReadCallback* readCallback_;          ///< ReadCallback
index 5f1ca7304083b800fa45cc1ca06844b41b41e556..de88b059420701ca9f0185f73079b46050cf4721 100644 (file)
@@ -52,6 +52,9 @@ const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
 
+constexpr size_t SSLClient::kMaxReadBufferSz;
+constexpr size_t SSLClient::kMaxReadsPerEvent;
+
 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
 ctx_(new folly::SSLContext),
     acb_(acb),
index 78623f77c92433814e0eb61645043b658798f874..a08fe6b949340568f3df5cc72259a27334b10b49 100644 (file)
@@ -980,10 +980,18 @@ class SSLClient : public AsyncSocket::ConnectCallback,
   uint32_t errors_;
   uint32_t writeAfterConnectErrors_;
 
+  // These settings test that we eventually drain the
+  // socket, even if the maxReadsPerEvent_ is hit during
+  // a event loop iteration.
+  static constexpr size_t kMaxReadsPerEvent = 2;
+  static constexpr size_t kMaxReadBufferSz =
+    sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
+
  public:
   SSLClient(EventBase *eventBase,
             const folly::SocketAddress& address,
-            uint32_t requests, uint32_t timeout = 0)
+            uint32_t requests,
+            uint32_t timeout = 0)
       : eventBase_(eventBase),
         session_(nullptr),
         requests_(requests),
@@ -1046,6 +1054,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
     }
 
     // write()
+    sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
     sslSocket_->write(this, buf_, sizeof(buf_));
     sslSocket_->setReadCB(this);
     memset(readbuf_, 'b', sizeof(readbuf_));
@@ -1075,7 +1084,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
 
   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
     *bufReturn = readbuf_ + bytesRead_;
-    *lenReturn = sizeof(readbuf_) - bytesRead_;
+    *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
   }
 
   void readEOF() noexcept override {
@@ -1090,7 +1099,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
   void readDataAvailable(size_t len) noexcept override {
     std::cerr << "client read data: " << len << std::endl;
     bytesRead_ += len;
-    if (len == sizeof(buf_)) {
+    if (bytesRead_ == sizeof(buf_)) {
       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
       sslSocket_->closeNow();
       sslSocket_.reset();