From 4708133f8d0c240fcb52e14a6c1c23f4479df1d1 Mon Sep 17 00:00:00 2001 From: Subodh Iyengar Date: Thu, 28 Apr 2016 12:10:01 -0700 Subject: [PATCH] Stop abusing errno Summary: We abuse errno to propagate exceptions from AsyncSSLSocket. Stop doing this and propagate exceptions correctly. This also formats the exception messages better. Reviewed By: anirudhvr Differential Revision: D3226808 fb-gh-sync-id: 15a5e67b0332136857e5fb85b1765757e548e040 fbshipit-source-id: 15a5e67b0332136857e5fb85b1765757e548e040 --- folly/Makefile.am | 2 + folly/io/async/AsyncSSLSocket.cpp | 172 ++++++------------ folly/io/async/AsyncSSLSocket.h | 39 ++-- folly/io/async/AsyncSocket.cpp | 61 ++++--- folly/io/async/AsyncSocket.h | 70 +++++-- folly/io/async/ssl/SSLErrors.cpp | 89 +++++++++ folly/io/async/ssl/SSLErrors.h | 63 +++++++ folly/io/async/test/AsyncSSLSocketTest.cpp | 78 +++++++- folly/io/async/test/AsyncSSLSocketTest.h | 95 +++++++++- .../io/async/test/AsyncSSLSocketWriteTest.cpp | 8 +- folly/io/async/test/BlockingSocket.h | 5 + 11 files changed, 488 insertions(+), 194 deletions(-) create mode 100644 folly/io/async/ssl/SSLErrors.cpp create mode 100644 folly/io/async/ssl/SSLErrors.h diff --git a/folly/Makefile.am b/folly/Makefile.am index f82d4898..17393794 100644 --- a/folly/Makefile.am +++ b/folly/Makefile.am @@ -234,6 +234,7 @@ nobase_follyinclude_HEADERS = \ io/async/HHWheelTimer.h \ io/async/ssl/OpenSSLPtrTypes.h \ io/async/ssl/OpenSSLUtils.h \ + io/async/ssl/SSLErrors.h \ io/async/ssl/TLSDefinitions.h \ io/async/Request.h \ io/async/SSLContext.h \ @@ -417,6 +418,7 @@ libfolly_la_SOURCES = \ io/async/test/SocketPair.cpp \ io/async/test/TimeUtil.cpp \ io/async/ssl/OpenSSLUtils.cpp \ + io/async/ssl/SSLErrors.cpp \ json.cpp \ detail/MemoryIdler.cpp \ MacAddress.cpp \ diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 0decbbc8..2146d6b9 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -62,10 +62,6 @@ using folly::SSLContext; static SSLContext *dummyCtx = nullptr; static SpinLock dummyCtxLock; -// Numbers chosen as to not collide with functions in ssl.h -const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90; -const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91; - // If given min write size is less than this, buffer will be allocated on // stack, otherwise it is allocated on heap const size_t MAX_STACK_BUF_SIZE = 2048; @@ -246,39 +242,10 @@ void* initEorBioMethod(void) { return nullptr; } -std::string decodeOpenSSLError(int sslError, - unsigned long errError, - int sslOperationReturnValue) { - if (sslError == SSL_ERROR_SYSCALL && errError == 0) { - if (sslOperationReturnValue == 0) { - return "SSL_ERROR_SYSCALL: EOF"; - } else { - // In this case errno is set, AsyncSocketException will add it. - return "SSL_ERROR_SYSCALL"; - } - } else if (sslError == SSL_ERROR_ZERO_RETURN) { - // This signifies a TLS closure alert. - return "SSL_ERROR_ZERO_RETURN"; - } else { - char buf[256]; - std::string msg(ERR_error_string(errError, buf)); - return msg; - } -} - } // anonymous namespace namespace folly { -SSLException::SSLException(int sslError, - unsigned long errError, - int sslOperationReturnValue, - int errno_copy) - : AsyncSocketException( - AsyncSocketException::SSL_ERROR, - decodeOpenSSLError(sslError, errError, sslOperationReturnValue), - sslError == SSL_ERROR_SYSCALL ? errno_copy : 0) {} - /** * Create a client AsyncSSLSocket */ @@ -807,6 +774,10 @@ SSL_SESSION *AsyncSSLSocket::getSSLSession() { return sslSession_; } +const SSL* AsyncSSLSocket::getSSL() const { + return ssl_; +} + void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) { sslSession_ = session; if (!takeOwnership && session != nullptr) { @@ -967,8 +938,6 @@ bool AsyncSSLSocket::willBlock(int ret, // The timeout (if set) keeps running here return true; } else { - // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail - // in the log unsigned long lastError = *errErrorOut = ERR_get_error(); VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " << "state=" << state_ << ", " @@ -981,7 +950,6 @@ bool AsyncSSLSocket::willBlock(int ret, << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", " << "func: " << ERR_func_error_string(lastError) << ", " << "reason: " << ERR_reason_error_string(lastError); - ERR_clear_error(); return false; } } @@ -1055,7 +1023,6 @@ AsyncSSLSocket::handleAccept() noexcept { SSL_set_msg_callback_arg(ssl_, this); } - errno = 0; int ret = SSL_accept(ssl_); if (ret <= 0) { int sslError; @@ -1119,7 +1086,6 @@ AsyncSSLSocket::handleConnect() noexcept { sslState_ == STATE_CONNECTING); assert(ssl_); - errno = 0; int ret = SSL_connect(ssl_); if (ret <= 0) { int sslError; @@ -1223,16 +1189,15 @@ AsyncSSLSocket::handleRead() noexcept { AsyncSocket::handleRead(); } -ssize_t +AsyncSocket::ReadResult AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { - VLOG(4) << "AsyncSSLSocket::performRead() this=" << this - << ", buf=" << *buf << ", buflen=" << *buflen; + VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf + << ", buflen=" << *buflen; if (sslState_ == STATE_UNENCRYPTED) { return AsyncSocket::performRead(buf, buflen, offset); } - errno = 0; ssize_t bytes = 0; if (!isBufferMovable_) { bytes = SSL_read(ssl_, *buf, *buflen); @@ -1247,20 +1212,18 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): client intitiated SSL renegotiation not permitted"; - // We pack our own SSLerr here with a dummy function - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ, - SSL_CLIENT_RENEGOTIATION_ATTEMPT); - ERR_clear_error(); - return READ_ERROR; + return ReadResult( + READ_ERROR, + folly::make_unique(SSLError::CLIENT_RENEGOTIATION)); } if (bytes <= 0) { int error = SSL_get_error(ssl_, bytes); if (error == SSL_ERROR_WANT_READ) { // The caller will register for read event if not already. if (errno == EWOULDBLOCK || errno == EAGAIN) { - return READ_BLOCKING; + return ReadResult(READ_BLOCKING); } else { - return READ_ERROR; + return ReadResult(READ_ERROR); } } else if (error == SSL_ERROR_WANT_WRITE) { // TODO: Even though we are attempting to read data, SSL_read() may @@ -1268,17 +1231,15 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { // don't support this and just fail the read. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ - << "): unsupported SSL renegotiation during read", - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ, - SSL_INVALID_RENEGOTIATION); - ERR_clear_error(); - return READ_ERROR; + << "): unsupported SSL renegotiation during read"; + return ReadResult( + READ_ERROR, + folly::make_unique(SSLError::INVALID_RENEGOTIATION)); } else { - // TODO: Fix this code so that it can return a proper error message - // to the callback, rather than relying on AsyncSocket code which - // can't handle SSL errors. - long lastError = ERR_get_error(); - + if (zero_return(error, bytes)) { + return ReadResult(bytes); + } + long errError = ERR_get_error(); VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " << "state=" << state_ << ", " << "sslState=" << sslState_ << ", " @@ -1286,24 +1247,15 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { << "bytes: " << bytes << ", " << "error: " << error << ", " << "errno: " << errno << ", " - << "func: " << ERR_func_error_string(lastError) << ", " - << "reason: " << ERR_reason_error_string(lastError); - ERR_clear_error(); - if (zero_return(error, bytes)) { - return bytes; - } - if (error != SSL_ERROR_SYSCALL) { - if ((unsigned long)lastError < 0x8000) { - errno = ENOSYS; - } else { - errno = lastError; - } - } - return READ_ERROR; + << "func: " << ERR_func_error_string(errError) << ", " + << "reason: " << ERR_reason_error_string(errError); + return ReadResult( + READ_ERROR, + folly::make_unique(error, errError, bytes, errno)); } } else { appBytesReceived_ += bytes; - return bytes; + return ReadResult(bytes); } } @@ -1331,49 +1283,40 @@ void AsyncSSLSocket::handleWrite() noexcept { AsyncSocket::handleWrite(); } -int AsyncSSLSocket::interpretSSLError(int rc, int error) { +AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) { if (error == SSL_ERROR_WANT_READ) { - // TODO: Even though we are attempting to write data, SSL_write() may + // Even though we are attempting to write data, SSL_write() may // need to read data if renegotiation is being performed. We currently // don't support this and just fail the write. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ - << "): " << "unsupported SSL renegotiation during write", - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE, - SSL_INVALID_RENEGOTIATION); - ERR_clear_error(); - return -1; + << "): " + << "unsupported SSL renegotiation during write"; + return WriteResult( + WRITE_ERROR, + folly::make_unique(SSLError::INVALID_RENEGOTIATION)); } else { - // TODO: Fix this code so that it can return a proper error message - // to the callback, rather than relying on AsyncSocket code which - // can't handle SSL errors. - long lastError = ERR_get_error(); + if (zero_return(error, rc)) { + return WriteResult(0); + } + auto errError = ERR_get_error(); VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " << "SSL error: " << error << ", errno: " << errno - << ", func: " << ERR_func_error_string(lastError) - << ", reason: " << ERR_reason_error_string(lastError); - if (error != SSL_ERROR_SYSCALL) { - if ((unsigned long)lastError < 0x8000) { - errno = ENOSYS; - } else { - errno = lastError; - } - } - ERR_clear_error(); - if (!zero_return(error, rc)) { - return -1; - } else { - return 0; - } + << ", func: " << ERR_func_error_string(errError) + << ", reason: " << ERR_reason_error_string(errError); + return WriteResult( + WRITE_ERROR, + folly::make_unique(error, errError, rc, errno)); } } -ssize_t AsyncSSLSocket::performWrite(const iovec* vec, - uint32_t count, - WriteFlags flags, - uint32_t* countWritten, - uint32_t* partialWritten) { +AsyncSocket::WriteResult AsyncSSLSocket::performWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) { if (sslState_ == STATE_UNENCRYPTED) { return AsyncSocket::performWrite( vec, count, flags, countWritten, partialWritten); @@ -1384,9 +1327,8 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, << ", events=" << eventFlags_ << "): " << "TODO: AsyncSSLSocket currently does not support calling " << "write() before the handshake has fully completed"; - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE, - SSL_EARLY_WRITE); - return -1; + return WriteResult( + WRITE_ERROR, folly::make_unique(SSLError::EARLY_WRITE)); } bool cork = isSet(flags, WriteFlags::CORK); @@ -1420,7 +1362,6 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, buf = ((const char*)v->iov_base) + offset; ssize_t bytes; - errno = 0; uint32_t buffersStolen = 0; if ((len < minWriteSize_) && ((i + 1) < count)) { // Combine this buffer with part or all of the next buffers in @@ -1474,11 +1415,11 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, if (error == SSL_ERROR_WANT_WRITE) { // The caller will register for write event if not already. *partialWritten = offset; - return totalWritten; + return WriteResult(totalWritten); } - int rc = interpretSSLError(bytes, error); - if (rc < 0) { - return rc; + auto writeResult = interpretSSLError(bytes, error); + if (writeResult.writeReturn < 0) { + return writeResult; } // else fall through to below to correctly record totalWritten } @@ -1500,11 +1441,11 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, v = &(vec[++i]); } *partialWritten = bytes; - return totalWritten; + return WriteResult(totalWritten); } } - return totalWritten; + return WriteResult(totalWritten); } int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, @@ -1575,7 +1516,6 @@ int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) { flags = MSG_EOR; } - errno = 0; ret = sendmsg(b->num, &msg, flags); BIO_clear_retry_flags(b); if (ret <= 0) { diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 1bb3fc17..af3fd06b 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -35,14 +36,6 @@ namespace folly { -class SSLException: public folly::AsyncSocketException { - public: - SSLException(int sslError, - unsigned long errError, - int sslOperationReturnValue, - int errno_copy); -}; - /** * A class for performing asynchronous I/O on an SSL connection. * @@ -143,18 +136,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { AsyncSSLSocket* sslSocket_; }; - - /** - * These are passed to the application via errno, packed in an SSL err which - * are outside the valid errno range. The values are chosen to be unique - * against values in ssl.h - */ - enum SSLError { - SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900, - SSL_INVALID_RENEGOTIATION = 901, - SSL_EARLY_WRITE = 902 - }; - /** * Create a client AsyncSSLSocket */ @@ -365,6 +346,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ SSL_SESSION *getSSLSession(); + /** + * Get a handle to the SSL struct. + */ + const SSL* getSSL() const; + /** * Set the SSL session to be used during sslConn. AsyncSSLSocket will * hold a reference to the session until it is destroyed or released by the @@ -760,11 +746,14 @@ class AsyncSSLSocket : public virtual AsyncSocket { // AsyncSocket calls this at the wrong time for SSL void handleInitialReadWrite() noexcept override {} - int interpretSSLError(int rc, int error); - ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override; - ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags, - uint32_t* countWritten, uint32_t* partialWritten) - override; + WriteResult interpretSSLError(int rc, int error); + ReadResult performRead(void** buf, size_t* buflen, size_t* offset) override; + WriteResult performWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) override; ssize_t performWriteIovec(const iovec* vec, uint32_t count, WriteFlags flags, uint32_t* countWritten, diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 6fd3c355..848cf5c4 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -91,14 +91,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { free(this); } - bool performWrite() override { + WriteResult performWrite() override { WriteFlags writeFlags = flags_; if (getNext() != nullptr) { writeFlags = writeFlags | WriteFlags::CORK; } - bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags, - &opsWritten_, &partialBytes_); - return bytesWritten_ >= 0; + return socket_->performWrite( + getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_); } bool isComplete() override { @@ -694,10 +693,14 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, assert(writeReqTail_ == nullptr); assert((eventFlags_ & EventHandler::WRITE) == 0); - bytesWritten = performWrite(vec, count, flags, - &countWritten, &partialWritten); + auto writeResult = + performWrite(vec, count, flags, &countWritten, &partialWritten); + bytesWritten = writeResult.writeReturn; if (bytesWritten < 0) { auto errnoCopy = errno; + if (writeResult.exception) { + return failWrite(__func__, callback, 0, *writeResult.exception); + } AsyncSocketException ex( AsyncSocketException::INTERNAL_ERROR, withAddr("writev failed"), @@ -1259,11 +1262,10 @@ void AsyncSocket::ioReady(uint16_t events) noexcept { } } -ssize_t AsyncSocket::performRead(void** buf, - size_t* buflen, - size_t* /* offset */) { - VLOG(5) << "AsyncSocket::performRead() this=" << this - << ", buf=" << *buf << ", buflen=" << *buflen; +AsyncSocket::ReadResult +AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) { + VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf + << ", buflen=" << *buflen; int recvFlags = 0; if (peek_) { @@ -1274,13 +1276,13 @@ ssize_t AsyncSocket::performRead(void** buf, if (bytes < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // No more data to read right now. - return READ_BLOCKING; + return ReadResult(READ_BLOCKING); } else { - return READ_ERROR; + return ReadResult(READ_ERROR); } } else { appBytesReceived_ += bytes; - return bytes; + return ReadResult(bytes); } } @@ -1347,7 +1349,8 @@ void AsyncSocket::handleRead() noexcept { } // Perform the read - ssize_t bytesRead = performRead(&buf, &buflen, &offset); + auto readResult = performRead(&buf, &buflen, &offset); + auto bytesRead = readResult.readReturn; VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got " << bytesRead << " bytes"; if (bytesRead > 0) { @@ -1376,6 +1379,9 @@ void AsyncSocket::handleRead() noexcept { return; } else if (bytesRead == READ_ERROR) { readErr_ = READ_ERROR; + if (readResult.exception) { + return failRead(__func__, *readResult.exception); + } auto errnoCopy = errno; AsyncSocketException ex( AsyncSocketException::INTERNAL_ERROR, @@ -1439,7 +1445,11 @@ void AsyncSocket::handleWrite() noexcept { // (See the comment in handleRead() explaining how this can happen.) EventBase* originalEventBase = eventBase_; while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) { - if (!writeReqHead_->performWrite()) { + auto writeResult = writeReqHead_->performWrite(); + if (writeResult.writeReturn < 0) { + if (writeResult.exception) { + return failWrite(__func__, *writeResult.exception); + } auto errnoCopy = errno; AsyncSocketException ex( AsyncSocketException::INTERNAL_ERROR, @@ -1697,11 +1707,12 @@ void AsyncSocket::timeoutExpired() noexcept { } } -ssize_t AsyncSocket::performWrite(const iovec* vec, - uint32_t count, - WriteFlags flags, - uint32_t* countWritten, - uint32_t* partialWritten) { +AsyncSocket::WriteResult AsyncSocket::performWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) { // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL // We correctly handle EPIPE errors, so we never want to receive SIGPIPE // (since it may terminate the program if the main program doesn't explicitly @@ -1736,12 +1747,12 @@ ssize_t AsyncSocket::performWrite(const iovec* vec, // TCP buffer is full; we can't write any more data right now. *countWritten = 0; *partialWritten = 0; - return 0; + return WriteResult(0); } // error *countWritten = 0; *partialWritten = 0; - return -1; + return WriteResult(WRITE_ERROR); } appBytesWritten_ += totalWritten; @@ -1754,7 +1765,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec, // Partial write finished in the middle of this iovec *countWritten = n; *partialWritten = bytesWritten; - return totalWritten; + return WriteResult(totalWritten); } bytesWritten -= v->iov_len; @@ -1763,7 +1774,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec, assert(bytesWritten == 0); *countWritten = n; *partialWritten = 0; - return totalWritten; + return WriteResult(totalWritten); } /** diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index ba706747..37fdf08e 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -16,16 +16,17 @@ #pragma once -#include -#include +#include #include -#include #include -#include +#include #include +#include #include -#include #include +#include +#include +#include #include #include @@ -517,6 +518,41 @@ class AsyncSocket : virtual public AsyncTransportWrapper { void setBufferCallback(BufferCallback* cb); + /** + * writeReturn is the total number of bytes written, or WRITE_ERROR on error. + * If no data has been written, 0 is returned. + * exception is a more specific exception that cause a write error. + * Not all writes have exceptions associated with them thus writeReturn + * should be checked to determine whether the operation resulted in an error. + */ + struct WriteResult { + explicit WriteResult(ssize_t ret) : writeReturn(ret) {} + + WriteResult(ssize_t ret, std::unique_ptr e) + : writeReturn(ret), exception(std::move(e)) {} + + ssize_t writeReturn; + std::unique_ptr exception; + }; + + /** + * readReturn is the number of bytes read, or READ_EOF on EOF, or + * READ_ERROR on error, or READ_BLOCKING if the operation will + * block. + * exception is a more specific exception that may have caused a read error. + * Not all read errors have exceptions associated with them thus readReturn + * should be checked to determine whether the operation resulted in an error. + */ + struct ReadResult { + explicit ReadResult(ssize_t ret) : readReturn(ret) {} + + ReadResult(ssize_t ret, std::unique_ptr e) + : readReturn(ret), exception(std::move(e)) {} + + ssize_t readReturn; + std::unique_ptr exception; + }; + /** * A WriteRequest object tracks information about a pending write operation. */ @@ -529,7 +565,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { virtual void destroy() = 0; - virtual bool performWrite() = 0; + virtual WriteResult performWrite() = 0; virtual void consume() = 0; @@ -579,6 +615,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper { READ_NO_ERROR = -3, }; + enum WriteResultEnum { + WRITE_ERROR = -1, + }; + /** * Protected destructor. * @@ -683,11 +723,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * @param buf The buffer to read data into. * @param buflen The length of the buffer. * - * @return Returns the number of bytes read, or READ_EOF on EOF, or - * READ_ERROR on error, or READ_BLOCKING if the operation will - * block. + * @return Returns a read result. See read result for details. */ - virtual ssize_t performRead(void** buf, size_t* buflen, size_t* offset); + virtual ReadResult performRead(void** buf, size_t* buflen, size_t* offset); /** * Populate an iovec array from an IOBuf and attempt to write it. @@ -736,12 +774,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * will contain the number of bytes written in the * partially written iovec entry. * - * @return Returns the total number of bytes written, or -1 on error. If no - * data can be written immediately, 0 is returned. + * @return Returns a WriteResult. See WriteResult for more details. */ - virtual ssize_t performWrite(const iovec* vec, uint32_t count, - WriteFlags flags, uint32_t* countWritten, - uint32_t* partialWritten); + virtual WriteResult performWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten); bool updateEventRegistration(); diff --git a/folly/io/async/ssl/SSLErrors.cpp b/folly/io/async/ssl/SSLErrors.cpp new file mode 100644 index 00000000..94550f5d --- /dev/null +++ b/folly/io/async/ssl/SSLErrors.cpp @@ -0,0 +1,89 @@ +/* + * Copyright 2016 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include + +using namespace folly; + +namespace { + +std::string decodeOpenSSLError( + int sslError, + unsigned long errError, + int sslOperationReturnValue) { + if (sslError == SSL_ERROR_SYSCALL && errError == 0) { + if (sslOperationReturnValue == 0) { + return "SSL_ERROR_SYSCALL: EOF"; + } else { + // In this case errno is set, AsyncSocketException will add it. + return "SSL_ERROR_SYSCALL"; + } + } else if (sslError == SSL_ERROR_ZERO_RETURN) { + // This signifies a TLS closure alert. + return "SSL_ERROR_ZERO_RETURN"; + } else { + std::array buf; + std::string msg(ERR_error_string(errError, buf.data())); + return msg; + } +} + +const StringPiece getSSLErrorString(SSLError error) { + StringPiece ret; + switch (error) { + case SSLError::CLIENT_RENEGOTIATION: + ret = "Client tried to renegotiate with server"; + break; + case SSLError::INVALID_RENEGOTIATION: + ret = "Attempt to start renegotiation, but unsupported"; + break; + case SSLError::EARLY_WRITE: + ret = "Attempt to write before SSL connection established"; + break; + case SSLError::OPENSSL_ERR: + // decodeOpenSSLError should be used for this type. + ret = "OPENSSL error"; + break; + } + return ret; +} +} + +namespace folly { + +SSLException::SSLException( + int sslError, + unsigned long errError, + int sslOperationReturnValue, + int errno_copy) + : AsyncSocketException( + AsyncSocketException::SSL_ERROR, + decodeOpenSSLError(sslError, errError, sslOperationReturnValue), + sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), + sslError(SSLError::OPENSSL_ERR), + opensslSSLError(sslError), + opensslErr(errError) {} + +SSLException::SSLException(SSLError error) + : AsyncSocketException( + AsyncSocketException::SSL_ERROR, + getSSLErrorString(error).str(), + 0), + sslError(error) {} +} diff --git a/folly/io/async/ssl/SSLErrors.h b/folly/io/async/ssl/SSLErrors.h new file mode 100644 index 00000000..ad7a5de4 --- /dev/null +++ b/folly/io/async/ssl/SSLErrors.h @@ -0,0 +1,63 @@ +/* + * Copyright 2016 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { + +enum class SSLError { + CLIENT_RENEGOTIATION, // A client tried to renegotiate with this server + INVALID_RENEGOTIATION, // We attempted to start a renegotiation. + EARLY_WRITE, // Wrote before SSL connection established. + // An openssl error type. The openssl specific methods should be used + // to find the real error type. + // This exists for compatibility until all error types can be move to proper + // errors. + OPENSSL_ERR, +}; + +class SSLException : public folly::AsyncSocketException { + public: + SSLException( + int sslError, + unsigned long errError, + int sslOperationReturnValue, + int errno_copy); + + explicit SSLException(SSLError error); + + SSLError getType() const { + return sslError; + } + + // These methods exist for compatibility until there are proper exceptions + // for all ssl error types. + int getOpensslSSLError() const { + return opensslSSLError; + } + + unsigned long getOpensslErr() const { + return opensslErr; + } + + private: + SSLError sslError; + int opensslSSLError; + unsigned long opensslErr; +}; +} diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index 038f4f31..bd0242aa 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -201,13 +201,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) { cerr << "ConnectWriteReadClose test completed" << endl; } +/** + * Test reading after server close. + */ +TEST(AsyncSSLSocketTest, ReadAfterClose) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadEOFCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + auto server = folly::make_unique(&acceptCallback); + + // Set up SSL context. + auto sslContext = std::make_shared(); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + auto socket = + std::make_shared(server->getAddress(), sslContext); + socket->open(); + + // This should trigger an EOF on the client. + auto evb = handshakeCallback.getSocket()->getEventBase(); + evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); }); + std::array readbuf; + auto bytesRead = socket->read(readbuf.data(), readbuf.size()); + EXPECT_EQ(0, bytesRead); +} + +/** + * Test bad renegotiation + */ +TEST(AsyncSSLSocketTest, Renegotiate) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + std::array fds; + getfds(fds.data()); + getctx(clientCtx, dfServerCtx); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + SSLHandshakeClient client(std::move(clientSock), true, true); + RenegotiatingServer server(std::move(serverSock)); + + while (!client.handshakeSuccess_ && !client.handshakeError_) { + eventBase.loopOnce(); + } + + ASSERT_TRUE(client.handshakeSuccess_); + + auto sslSock = std::move(client).moveSocket(); + sslSock->detachEventBase(); + // This is nasty, however we don't want to add support for + // renegotiation in AsyncSSLSocket. + SSL_renegotiate(const_cast(sslSock->getSSL())); + + auto socket = std::make_shared(std::move(sslSock)); + + std::thread t([&]() { eventBase.loopForever(); }); + + // Trigger the renegotiation. + std::array buf; + memset(buf.data(), 'a', buf.size()); + try { + socket->write(buf.data(), buf.size()); + } catch (AsyncSocketException& e) { + LOG(INFO) << "client got error " << e.what(); + } + eventBase.terminateLoopSoon(); + t.join(); + + eventBase.loop(); + ASSERT_TRUE(server.renegotiationError_); +} + /** * Negative test for handshakeError(). */ TEST(AsyncSSLSocketTest, HandshakeError) { // Start listening on a local port WriteCallbackBase writeCallback; - ReadCallback readCallback(&writeCallback); + WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); HandshakeErrorCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); diff --git a/folly/io/async/test/AsyncSSLSocketTest.h b/folly/io/async/test/AsyncSSLSocketTest.h index a4e18aaa..69966d67 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.h +++ b/folly/io/async/test/AsyncSSLSocketTest.h @@ -18,13 +18,15 @@ #include #include -#include +#include +#include #include +#include #include +#include #include #include -#include -#include +#include #include #include @@ -58,7 +60,7 @@ public: , exception(AsyncSocketException::UNKNOWN, "none") {} ~WriteCallbackBase() { - EXPECT_EQ(state, STATE_SUCCEEDED); + EXPECT_EQ(STATE_SUCCEEDED, state); } void setSocket( @@ -92,10 +94,9 @@ public: class ReadCallbackBase : public AsyncTransportWrapper::ReadCallback { -public: - explicit ReadCallbackBase(WriteCallbackBase *wcb) - : wcb_(wcb) - , state(STATE_WAITING) {} + public: + explicit ReadCallbackBase(WriteCallbackBase* wcb) + : wcb_(wcb), state(STATE_WAITING) {} ~ReadCallbackBase() { EXPECT_EQ(state, STATE_SUCCEEDED); @@ -222,6 +223,27 @@ public: } }; +class ReadEOFCallback : public ReadCallbackBase { + public: + explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {} + + // Return nullptr buffer to trigger readError() + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *bufReturn = nullptr; + *lenReturn = 0; + } + + void readDataAvailable(size_t /* len */) noexcept override { + // This should never to called. + FAIL(); + } + + void readEOF() noexcept override { + ReadCallbackBase::readEOF(); + setState(STATE_SUCCEEDED); + } +}; + class WriteErrorCallback : public ReadCallback { public: explicit WriteErrorCallback(WriteCallbackBase *wcb) @@ -340,6 +362,10 @@ public: state = STATE_SUCCEEDED; } + std::shared_ptr getSocket() { + return socket_; + } + StateEnum state; std::shared_ptr socket_; ReadCallbackBase *rcb_; @@ -879,6 +905,48 @@ class NpnServer : AsyncSSLSocket::UniquePtr socket_; }; +class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB, + public AsyncTransportWrapper::ReadCallback { + public: + explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket) + : socket_(std::move(socket)) { + socket_->sslAccept(this); + } + + ~RenegotiatingServer() { + socket_->setReadCB(nullptr); + } + + void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override { + LOG(INFO) << "Renegotiating server handshake success"; + socket_->setReadCB(this); + } + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what(); + } + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + *lenReturn = sizeof(buf); + *bufReturn = buf; + } + void readDataAvailable(size_t /* len */) noexcept override {} + void readEOF() noexcept override {} + void readErr(const AsyncSocketException& ex) noexcept override { + LOG(INFO) << "server got read error " << ex.what(); + auto exPtr = dynamic_cast(&ex); + ASSERT_NE(nullptr, exPtr); + std::string exStr(ex.what()); + SSLException sslEx(SSLError::CLIENT_RENEGOTIATION); + ASSERT_NE(std::string::npos, exStr.find(sslEx.what())); + renegotiationError_ = true; + } + + AsyncSSLSocket::UniquePtr socket_; + unsigned char buf[128]; + bool renegotiationError_{false}; +}; + #ifndef OPENSSL_NO_TLSEXT class SNIClient : private AsyncSSLSocket::HandshakeCB, @@ -1139,6 +1207,10 @@ class SSLHandshakeBase : verifyResult_(verifyResult) { } + AsyncSSLSocket::UniquePtr moveSocket() && { + return std::move(socket_); + } + bool handshakeVerify_; bool handshakeSuccess_; bool handshakeError_; @@ -1160,12 +1232,15 @@ class SSLHandshakeBase : } void handshakeSuc(AsyncSSLSocket*) noexcept override { + LOG(INFO) << "Handshake success"; handshakeSuccess_ = true; handshakeTime = socket_->getHandshakeTime(); } - void handshakeErr(AsyncSSLSocket*, - const AsyncSocketException& /* ex */) noexcept override { + void handshakeErr( + AsyncSSLSocket*, + const AsyncSocketException& ex) noexcept override { + LOG(INFO) << "Handshake error " << ex.what(); handshakeError_ = true; handshakeTime = socket_->getHandshakeTime(); } diff --git a/folly/io/async/test/AsyncSSLSocketWriteTest.cpp b/folly/io/async/test/AsyncSSLSocketWriteTest.cpp index 4497e65f..4235517d 100644 --- a/folly/io/async/test/AsyncSSLSocketWriteTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketWriteTest.cpp @@ -58,8 +58,12 @@ class MockAsyncSSLSocket : public AsyncSSLSocket{ MOCK_CONST_METHOD0(getRawBytesWritten, size_t()); // public wrapper for protected interface - ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags, - uint32_t* countWritten, uint32_t* partialWritten) { + WriteResult testPerformWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) { return performWrite(vec, count, flags, countWritten, partialWritten); } diff --git a/folly/io/async/test/BlockingSocket.h b/folly/io/async/test/BlockingSocket.h index 7cfb870c..360cfcb1 100644 --- a/folly/io/async/test/BlockingSocket.h +++ b/folly/io/async/test/BlockingSocket.h @@ -35,6 +35,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, new folly::AsyncSocket(&eventBase_)), address_(address) {} + explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket) + : sock_(std::move(socket)) { + sock_->attachEventBase(&eventBase_); + } + void open() { sock_->connect(this, address_); eventBase_.loop(); -- 2.34.1