/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#pragma once
#include <signal.h>
-#include <pthread.h>
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
+#include <folly/experimental/TestUtil.h>
#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/io/async/test/TestSSLServer.h>
+#include <folly/portability/GTest.h>
+#include <folly/portability/PThread.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
-#include <gtest/gtest.h>
-#include <iostream>
-#include <list>
-#include <unistd.h>
#include <fcntl.h>
-#include <poll.h>
#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
+#include <condition_variable>
+#include <iostream>
+#include <list>
namespace folly {
-enum StateEnum {
- STATE_WAITING,
- STATE_SUCCEEDED,
- STATE_FAILED
-};
-
// The destructors of all callback classes assert that the state is
// STATE_SUCCEEDED, for both possitive and negative tests. The tests
// are responsible for setting the succeeded state properly before the
// destructors are called.
+class SendMsgParamsCallbackBase :
+ public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+ SendMsgParamsCallbackBase() {}
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) {
+ socket_ = socket;
+ oldCallback_ = socket_->getSendMsgParamsCB();
+ socket_->setSendMsgParamCB(this);
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ oldCallback_->getAncillaryData(flags, data);
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ return oldCallback_->getAncillaryDataSize(flags);
+ }
+
+ std::shared_ptr<AsyncSSLSocket> socket_;
+ folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
+};
+
+class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
+ public:
+ SendMsgFlagsCallback() {}
+
+ void resetFlags(int flags) {
+ flags_ = flags;
+ }
+
+ int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+ override {
+ if (flags_) {
+ return flags_;
+ } else {
+ return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
+ }
+ }
+
+ int flags_{0};
+};
+
+class SendMsgDataCallback : public SendMsgFlagsCallback {
+ public:
+ SendMsgDataCallback() {}
+
+ void resetData(std::vector<char>&& data) {
+ ancillaryData_.swap(data);
+ }
+
+ void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+ if (ancillaryData_.size()) {
+ std::cerr << "getAncillaryData: copying data" << std::endl;
+ memcpy(data, ancillaryData_.data(), ancillaryData_.size());
+ } else {
+ oldCallback_->getAncillaryData(flags, data);
+ }
+ }
+
+ uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+ if (ancillaryData_.size()) {
+ std::cerr << "getAncillaryDataSize: returning size" << std::endl;
+ return ancillaryData_.size();
+ } else {
+ return oldCallback_->getAncillaryDataSize(flags);
+ }
+ }
+
+ std::vector<char> ancillaryData_;
+};
+
class WriteCallbackBase :
public AsyncTransportWrapper::WriteCallback {
-public:
- WriteCallbackBase()
+ public:
+ explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
: state(STATE_WAITING)
, bytesWritten(0)
- , exception(AsyncSocketException::UNKNOWN, "none") {}
+ , exception(AsyncSocketException::UNKNOWN, "none")
+ , mcb_(mcb) {}
- ~WriteCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ ~WriteCallbackBase() override {
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
- void setSocket(
+ virtual void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
+ if (mcb_) {
+ mcb_->setSocket(socket);
+ }
}
void writeSuccess() noexcept override {
}
void writeErr(
- size_t bytesWritten,
+ size_t nBytesWritten,
const AsyncSocketException& ex) noexcept override {
- std::cerr << "writeError: bytesWritten " << bytesWritten
+ std::cerr << "writeError: bytesWritten " << nBytesWritten
<< ", exception " << ex.what() << std::endl;
state = STATE_FAILED;
- this->bytesWritten = bytesWritten;
+ this->bytesWritten = nBytesWritten;
exception = ex;
socket_->close();
- socket_->detachEventBase();
}
std::shared_ptr<AsyncSSLSocket> socket_;
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
+ SendMsgParamsCallbackBase* mcb_;
};
+class ExpectWriteErrorCallback :
+public WriteCallbackBase {
+ public:
+ explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+ : WriteCallbackBase(mcb) {}
+
+ ~ExpectWriteErrorCallback() override {
+ EXPECT_EQ(STATE_FAILED, state);
+ EXPECT_EQ(exception.type_,
+ AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
+ EXPECT_EQ(exception.errno_, 22);
+ // Suppress the assert in ~WriteCallbackBase()
+ state = STATE_SUCCEEDED;
+ }
+};
+
+#ifdef MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+ SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+ SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+ SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+ SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+ SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+ SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+};
+
+class WriteCheckTimestampCallback :
+ public WriteCallbackBase {
+ public:
+ explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+ : WriteCallbackBase(mcb) {}
+
+ ~WriteCheckTimestampCallback() override {
+ EXPECT_EQ(STATE_SUCCEEDED, state);
+ EXPECT_TRUE(gotTimestamp_);
+ EXPECT_TRUE(gotByteSeq_);
+ }
+
+ void setSocket(
+ const std::shared_ptr<AsyncSSLSocket> &socket) override {
+ WriteCallbackBase::setSocket(socket);
+
+ EXPECT_NE(socket_->getFd(), 0);
+ int flags = SOF_TIMESTAMPING_OPT_ID
+ | SOF_TIMESTAMPING_OPT_TSONLY
+ | SOF_TIMESTAMPING_SOFTWARE;
+ AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+ int ret = tstampingOpt.apply(socket_->getFd(), flags);
+ EXPECT_EQ(ret, 0);
+ }
+
+ void checkForTimestampNotifications() noexcept {
+ int fd = socket_->getFd();
+ std::vector<char> ctrl(1024, 0);
+ unsigned char data;
+ struct msghdr msg;
+ iovec entry;
+
+ memset(&msg, 0, sizeof(msg));
+ entry.iov_base = &data;
+ entry.iov_len = sizeof(data);
+ msg.msg_iov = &entry;
+ msg.msg_iovlen = 1;
+ msg.msg_control = ctrl.data();
+ msg.msg_controllen = ctrl.size();
+
+ int ret;
+ while (true) {
+ ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+ if (ret < 0) {
+ if (errno != EAGAIN) {
+ auto errnoCopy = errno;
+ std::cerr << "::recvmsg exited with code " << ret
+ << ", errno: " << errnoCopy << std::endl;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR,
+ "recvmsg() failed",
+ errnoCopy);
+ exception = ex;
+ }
+ return;
+ }
+
+ for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg != nullptr && cmsg->cmsg_len != 0;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_TIMESTAMPING) {
+ gotTimestamp_ = true;
+ continue;
+ }
+
+ if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
+ (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
+ gotByteSeq_ = true;
+ continue;
+ }
+ }
+ }
+ }
+
+ bool gotTimestamp_{false};
+ bool gotByteSeq_{false};
+};
+#endif // MSG_ERRQUEUE
+
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
-public:
- explicit ReadCallbackBase(WriteCallbackBase *wcb)
- : wcb_(wcb)
- , state(STATE_WAITING) {}
+ public:
+ explicit ReadCallbackBase(WriteCallbackBase* wcb)
+ : wcb_(wcb), state(STATE_WAITING) {}
- ~ReadCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ ~ReadCallbackBase() override {
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void setSocket(
std::cerr << "readError " << ex.what() << std::endl;
state = STATE_FAILED;
socket_->close();
- socket_->detachEventBase();
}
void readEOF() noexcept override {
std::cerr << "readEOF" << std::endl;
socket_->close();
- socket_->detachEventBase();
}
std::shared_ptr<AsyncSSLSocket> socket_;
};
class ReadCallback : public ReadCallbackBase {
-public:
+ public:
explicit ReadCallback(WriteCallbackBase *wcb)
: ReadCallbackBase(wcb)
, buffers() {}
- ~ReadCallback() {
+ ~ReadCallback() override {
for (std::vector<Buffer>::iterator it = buffers.begin();
it != buffers.end();
++it) {
}
class Buffer {
- public:
+ public:
Buffer() : buffer(nullptr), length(0) {}
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
buffer = nullptr;
length = 0;
}
- void allocate(size_t length) {
+ void allocate(size_t len) {
assert(buffer == nullptr);
- this->buffer = static_cast<char*>(malloc(length));
- this->length = length;
+ this->buffer = static_cast<char*>(malloc(len));
+ this->length = len;
}
void free() {
::free(buffer);
};
class ReadErrorCallback : public ReadCallbackBase {
-public:
+ public:
explicit ReadErrorCallback(WriteCallbackBase *wcb)
: ReadCallbackBase(wcb) {}
*lenReturn = 0;
}
- void readDataAvailable(size_t len) noexcept override {
+ void readDataAvailable(size_t /* len */) noexcept override {
// This should never to called.
FAIL();
}
}
};
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+ explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+ // Return nullptr buffer to trigger readError()
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = nullptr;
+ *lenReturn = 0;
+ }
+
+ void readDataAvailable(size_t /* len */) noexcept override {
+ // This should never to called.
+ FAIL();
+ }
+
+ void readEOF() noexcept override {
+ ReadCallbackBase::readEOF();
+ setState(STATE_SUCCEEDED);
+ }
+};
+
class WriteErrorCallback : public ReadCallback {
-public:
+ public:
explicit WriteErrorCallback(WriteCallbackBase *wcb)
: ReadCallback(wcb) {}
wcb_->setSocket(socket_);
// Write back the same data.
- socket_->write(wcb_, currentBuffer.buffer, len);
+ folly::test::msvcSuppressAbortOnInvalidParams([&] {
+ socket_->write(wcb_, currentBuffer.buffer, len);
+ });
if (wcb_->state == STATE_FAILED) {
setState(STATE_SUCCEEDED);
};
class EmptyReadCallback : public ReadCallback {
-public:
+ public:
explicit EmptyReadCallback()
: ReadCallback(nullptr) {}
void readErr(const AsyncSocketException& ex) noexcept override {
std::cerr << "readError " << ex.what() << std::endl;
state = STATE_FAILED;
- tcpSocket_->close();
- tcpSocket_->detachEventBase();
+ if (tcpSocket_) {
+ tcpSocket_->close();
+ }
}
void readEOF() noexcept override {
std::cerr << "readEOF" << std::endl;
-
- tcpSocket_->close();
- tcpSocket_->detachEventBase();
+ if (tcpSocket_) {
+ tcpSocket_->close();
+ }
state = STATE_SUCCEEDED;
}
class HandshakeCallback :
public AsyncSSLSocket::HandshakeCB {
-public:
+ public:
enum ExpectType {
EXPECT_SUCCESS,
EXPECT_ERROR
// Functions inherited from AsyncSSLSocketHandshakeCallback
void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+ std::lock_guard<std::mutex> g(mutex_);
+ cv_.notify_all();
EXPECT_EQ(sock, socket_.get());
std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
rcb_->setSocket(socket_);
sock->setReadCB(rcb_);
state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
}
- void handshakeErr(
- AsyncSSLSocket *sock,
- const AsyncSocketException& ex) noexcept override {
+ void handshakeErr(AsyncSSLSocket* /* sock */,
+ const AsyncSocketException& ex) noexcept override {
+ std::lock_guard<std::mutex> g(mutex_);
+ cv_.notify_all();
std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
if (expect_ == EXPECT_ERROR) {
// rcb will never be invoked
rcb_->setState(STATE_SUCCEEDED);
}
+ errorString_ = ex.what();
}
- ~HandshakeCallback() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ void waitForHandshake() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ cv_.wait(lock, [this] { return state != STATE_WAITING; });
+ }
+
+ ~HandshakeCallback() override {
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void closeSocket() {
state = STATE_SUCCEEDED;
}
+ std::shared_ptr<AsyncSSLSocket> getSocket() {
+ return socket_;
+ }
+
StateEnum state;
std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_;
ExpectType expect_;
-};
-
-class SSLServerAcceptCallbackBase:
-public folly::AsyncServerSocket::AcceptCallback {
-public:
- explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
- state(STATE_WAITING), hcb_(hcb) {}
-
- ~SSLServerAcceptCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
- }
-
- void acceptError(const std::exception& ex) noexcept override {
- std::cerr << "SSLServerAcceptCallbackBase::acceptError "
- << ex.what() << std::endl;
- state = STATE_FAILED;
- }
-
- void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
- noexcept override{
- printf("Connection accepted\n");
- std::shared_ptr<AsyncSSLSocket> sslSock;
- try {
- // Create a AsyncSSLSocket object with the fd. The socket should be
- // added to the event base and in the state of accepting SSL connection.
- sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
- } catch (const std::exception &e) {
- LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
- "object with socket " << e.what() << fd;
- ::close(fd);
- acceptError(e);
- return;
- }
-
- connAccepted(sslSock);
- }
-
- virtual void connAccepted(
- const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
-
- StateEnum state;
- HandshakeCallback *hcb_;
- std::shared_ptr<folly::SSLContext> ctx_;
- folly::EventBase* base_;
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ std::string errorString_;
};
class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
uint32_t timeout_;
explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
SSLServerAcceptCallbackBase(hcb),
timeout_(timeout) {}
- virtual ~SSLServerAcceptCallback() {
+ ~SSLServerAcceptCallback() override {
if (timeout_ > 0) {
// if we set a timeout, we expect failure
EXPECT_EQ(hcb_->state, STATE_FAILED);
}
}
- // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
hcb_->setSocket(sock);
- sock->sslAccept(hcb_, timeout_);
+ sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
EXPECT_EQ(sock->getSSLState(),
AsyncSSLSocket::STATE_ACCEPTING);
};
class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
-public:
+ public:
explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
SSLServerAcceptCallback(hcb) {}
- // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
};
class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
-public:
+ public:
explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
uint32_t timeout = 0):
SSLServerAcceptCallback(hcb, timeout) {}
- // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
hcb_->setSocket(sock);
- sock->sslAccept(hcb_, timeout_);
+ sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
ASSERT_TRUE((sock->getSSLState() ==
AsyncSSLSocket::STATE_ACCEPTING) ||
(sock->getSSLState() ==
class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
explicit HandshakeErrorCallback(HandshakeCallback *hcb):
SSLServerAcceptCallbackBase(hcb) {}
- // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
EXPECT_EQ(hcb_->state, STATE_FAILED);
EXPECT_EQ(callback2.state, STATE_FAILED);
- sock->detachEventBase();
-
state = STATE_SUCCEEDED;
hcb_->setState(STATE_SUCCEEDED);
callback2.setState(STATE_SUCCEEDED);
};
class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
SSLServerAcceptCallbackBase(hcb) {}
- // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
}
};
-
-class TestSSLServer {
- protected:
- EventBase evb_;
- std::shared_ptr<folly::SSLContext> ctx_;
- SSLServerAcceptCallbackBase *acb_;
- folly::AsyncServerSocket *socket_;
- folly::SocketAddress address_;
- pthread_t thread_;
-
- static void *Main(void *ctx) {
- TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
- self->evb_.loop();
- std::cerr << "Server thread exited event loop" << std::endl;
- return nullptr;
- }
-
+class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
public:
- // Create a TestSSLServer.
- // This immediately starts listening on the given port.
- explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
-
- // Kill the thread.
- ~TestSSLServer() {
- evb_.runInEventBaseThread([&](){
- socket_->stopAccepting();
- });
- std::cerr << "Waiting for server thread to exit" << std::endl;
- pthread_join(thread_, nullptr);
+ ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
+ // We don't care if we get invoked or not.
+ // The client may time out and give up before connAccepted() is even
+ // called.
+ state = STATE_SUCCEEDED;
}
- EventBase &getEventBase() { return evb_; }
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
+ std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
- const folly::SocketAddress& getAddress() const {
- return address_;
+ // Just wait a while before closing the socket, so the client
+ // will time out waiting for the handshake to complete.
+ s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
}
};
int lookupDelay = 100) :
TestSSLServer(acb) {
SSL_CTX *sslCtx = ctx_->getSSLCtx();
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
SSL_CTX_sess_set_get_cb(sslCtx,
TestSSLAsyncCacheServer::getSessionCallback);
+#endif
SSL_CTX_set_session_cache_mode(
sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
asyncCallbacks_ = 0;
static uint32_t asyncLookups_;
static uint32_t lookupDelay_;
- static SSL_SESSION *getSessionCallback(SSL *ssl,
- unsigned char *sess_id,
- int id_len,
- int *copyflag) {
+ static SSL_SESSION* getSessionCallback(SSL* ssl,
+ unsigned char* /* sess_id */,
+ int /* id_len */,
+ int* copyflag) {
*copyflag = 0;
asyncCallbacks_++;
+ (void)ssl;
#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
if (!SSL_want_sess_cache_lookup(ssl)) {
// libssl.so mismatch
}
}
- socket_->sslConn(this, 100);
+ socket_->sslConn(this, std::chrono::milliseconds(100));
}
struct iovec* getIovec() const {
bufSize_(2500 * 2000),
bytesRead_(0) {
buf_.reset(new uint8_t[bufSize_]);
- socket_->sslAccept(this, 100);
+ socket_->sslAccept(this, std::chrono::milliseconds(100));
}
void checkBuffer(struct iovec* iov, uint32_t count) const {
const unsigned char* nextProto;
unsigned nextProtoLength;
+ SSLContext::NextProtocolType protocolType;
+ folly::Optional<AsyncSocketException> except;
+
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
- socket_->getSelectedNextProtocol(&nextProto,
- &nextProtoLength);
+ socket_->getSelectedNextProtocol(
+ &nextProto, &nextProtoLength, &protocolType);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
- ADD_FAILURE() << "client handshake error: " << ex.what();
+ except = ex;
}
void writeSuccess() noexcept override {
socket_->close();
const unsigned char* nextProto;
unsigned nextProtoLength;
+ SSLContext::NextProtocolType protocolType;
+ folly::Optional<AsyncSocketException> except;
+
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
- socket_->getSelectedNextProtocol(&nextProto,
- &nextProtoLength);
+ socket_->getSelectedNextProtocol(
+ &nextProto, &nextProtoLength, &protocolType);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
- ADD_FAILURE() << "server handshake error: " << ex.what();
+ except = ex;
}
- void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
*lenReturn = 0;
}
- void readDataAvailable(size_t len) noexcept override {
- }
+ void readDataAvailable(size_t /* len */) noexcept override {}
void readEOF() noexcept override {
socket_->close();
}
AsyncSSLSocket::UniquePtr socket_;
};
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+ public AsyncTransportWrapper::ReadCallback {
+ public:
+ explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+ : socket_(std::move(socket)) {
+ socket_->sslAccept(this);
+ }
+
+ ~RenegotiatingServer() override {
+ socket_->setReadCB(nullptr);
+ }
+
+ void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+ LOG(INFO) << "Renegotiating server handshake success";
+ socket_->setReadCB(this);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *lenReturn = sizeof(buf);
+ *bufReturn = buf;
+ }
+ void readDataAvailable(size_t /* len */) noexcept override {}
+ void readEOF() noexcept override {}
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ LOG(INFO) << "server got read error " << ex.what();
+ auto exPtr = dynamic_cast<const SSLException*>(&ex);
+ ASSERT_NE(nullptr, exPtr);
+ std::string exStr(ex.what());
+ SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+ ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+ renegotiationError_ = true;
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ unsigned char buf[128];
+ bool renegotiationError_{false};
+};
+
#ifndef OPENSSL_NO_TLSEXT
class SNIClient :
private AsyncSSLSocket::HandshakeCB,
bool serverNameMatch;
private:
- void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
+ void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server handshake error: " << ex.what();
}
- void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
*lenReturn = 0;
}
- void readDataAvailable(size_t len) noexcept override {
- }
+ void readDataAvailable(size_t /* len */) noexcept override {}
void readEOF() noexcept override {
socket_->close();
}
uint32_t errors_;
uint32_t writeAfterConnectErrors_;
+ // These settings test that we eventually drain the
+ // socket, even if the maxReadsPerEvent_ is hit during
+ // a event loop iteration.
+ static constexpr size_t kMaxReadsPerEvent = 2;
+ // 2 event loop iterations
+ static constexpr size_t kMaxReadBufferSz =
+ sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
+
public:
SSLClient(EventBase *eventBase,
const folly::SocketAddress& address,
- uint32_t requests, uint32_t timeout = 0)
+ uint32_t requests,
+ uint32_t timeout = 0)
: eventBase_(eventBase),
session_(nullptr),
requests_(requests),
memset(buf_, 'a', sizeof(buf_));
}
- ~SSLClient() {
+ ~SSLClient() override {
if (session_) {
SSL_SESSION_free(session_);
}
}
// write()
+ sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
sslSocket_->write(this, buf_, sizeof(buf_));
sslSocket_->setReadCB(this);
memset(readbuf_, 'b', sizeof(readbuf_));
std::cerr << "client write success" << std::endl;
}
- void writeErr(
- size_t bytesWritten,
- const AsyncSocketException& ex)
- noexcept override {
+ void writeErr(size_t /* bytesWritten */,
+ const AsyncSocketException& ex) noexcept override {
std::cerr << "client writeError: " << ex.what() << std::endl;
if (!sslSocket_) {
writeAfterConnectErrors_++;
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = readbuf_ + bytesRead_;
- *lenReturn = sizeof(readbuf_) - bytesRead_;
+ *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
}
void readEOF() noexcept override {
void readDataAvailable(size_t len) noexcept override {
std::cerr << "client read data: " << len << std::endl;
bytesRead_ += len;
- if (len == sizeof(buf_)) {
+ if (bytesRead_ == sizeof(buf_)) {
EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
sslSocket_->closeNow();
sslSocket_.reset();
verifyResult_(verifyResult) {
}
+ AsyncSSLSocket::UniquePtr moveSocket() && {
+ return std::move(socket_);
+ }
+
bool handshakeVerify_;
bool handshakeSuccess_;
bool handshakeError_;
+ std::chrono::nanoseconds handshakeTime;
protected:
AsyncSSLSocket::UniquePtr socket_;
bool verifyResult_;
// HandshakeCallback
- bool handshakeVer(
- AsyncSSLSocket* sock,
- bool preverifyOk,
- X509_STORE_CTX* ctx) noexcept override {
+ bool handshakeVer(AsyncSSLSocket* /* sock */,
+ bool preverifyOk,
+ X509_STORE_CTX* /* ctx */) noexcept override {
handshakeVerify_ = true;
EXPECT_EQ(preverifyResult_, preverifyOk);
}
void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ LOG(INFO) << "Handshake success";
handshakeSuccess_ = true;
+ handshakeTime = socket_->getHandshakeTime();
}
void handshakeErr(
- AsyncSSLSocket*,
- const AsyncSocketException& ex) noexcept override {
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true;
+ handshakeTime = socket_->getHandshakeTime();
}
// WriteCallback
bool preverifyResult,
bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
- socket_->sslConn(this, 0);
+ socket_->sslConn(this, std::chrono::milliseconds::zero());
}
};
bool preverifyResult,
bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
- socket_->sslConn(this, 0,
- folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ socket_->sslConn(
+ this,
+ std::chrono::milliseconds::zero(),
+ folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
}
};
bool preverifyResult,
bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
- socket_->sslConn(this, 0,
- folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
+ socket_->sslConn(
+ this,
+ std::chrono::milliseconds::zero(),
+ folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
}
};
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
- socket_->sslAccept(this, 0);
+ socket_->sslAccept(this, std::chrono::milliseconds::zero());
}
};
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->enableClientHelloParsing();
- socket_->sslAccept(this, 0);
+ socket_->sslAccept(this, std::chrono::milliseconds::zero());
}
std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
- socket_->sslAccept(this, 0,
- folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+ socket_->sslAccept(
+ this,
+ std::chrono::milliseconds::zero(),
+ folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
}
};
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
- socket_->sslAccept(this, 0,
- folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+ socket_->sslAccept(
+ this,
+ std::chrono::milliseconds::zero(),
+ folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
}
};