2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
18 #include <folly/futures/Promise.h>
19 #include <folly/init/Init.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/io/async/SSLContext.h>
23 #include <folly/io/async/ScopedEventBaseThread.h>
24 #include <folly/portability/GTest.h>
25 #include <folly/portability/PThread.h>
// EvbAndContext bundles an event base (evb_, a folly::ScopedEventBaseThread)
// with a client SSLContext (ctx_); tests create sockets on it (createSocket)
// or move existing sockets onto it (attach).
36 struct EvbAndContext {
// Builds ctx_: a fresh SSLContext with session tickets disabled
// (SSL_OP_NO_TICKET) and the cipher list "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"
// (OpenSSL cipher-list syntax: all suites minus anonymous-DH, low-strength,
// export and MD5 ones, ordered by strength).
// NOTE(review): the signature of the enclosing member (presumably the
// constructor) is not visible in this view.
// NOTE(review): std::make_shared<SSLContext>() would avoid the naked `new`.
38 ctx_.reset(new SSLContext());
39 ctx_->setOptions(SSL_OP_NO_TICKET);
40 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
43 std::shared_ptr<AsyncSSLSocket> createSocket() {
44 return AsyncSSLSocket::newSocket(ctx_, getEventBase());
47 EventBase* getEventBase() {
48 return evb_.getEventBase();
51 void attach(AsyncSSLSocket& socket) {
52 socket.attachEventBase(getEventBase());
53 socket.attachSSLContext(ctx_);
// Event-base holder; getEventBase() returns its EventBase, which
// createSocket() and attach() bind sockets to.
56 folly::ScopedEventBaseThread evb_;
// Client SSL context handed to sockets by createSocket()/attach(); tests also
// read it directly (e.g. attachSSLContext(t1_.ctx_)).
57 std::shared_ptr<SSLContext> ctx_;
// AttachDetachClient: SSL client that moves one AsyncSSLSocket between two
// EvbAndContext bundles (created on t1_, connected on t2_, then written/read
// on t1_). promise_ is fulfilled with true once data matching buf_ has been
// read back, or failed on a read error.
60 class AttachDetachClient : public AsyncSocket::ConnectCallback,
61 public AsyncTransportWrapper::WriteCallback,
62 public AsyncTransportWrapper::ReadCallback {
64 // two threads here - we'll create the socket in one, connect
65 // in the other, and then read/write in the initial one
68 std::shared_ptr<AsyncSSLSocket> sslSocket_;
69 folly::SocketAddress address_;
73 // promise to fulfill when done
74 folly::Promise<bool> promise_;
// NOTE(review): the member these two lines belong to is not visible in this
// view; they release the socket from its current event base and SSL context.
77 sslSocket_->detachEventBase();
78 sslSocket_->detachSSLContext();
// address: server endpoint the socket later connects to; bytesRead_ starts
// at zero.
82 explicit AttachDetachClient(const folly::SocketAddress& address)
83 : address_(address), bytesRead_(0) {}
85 Future<bool> getFuture() {
86 return promise_.getFuture();
// Connect sequence (asynchronous): on t1_'s event-base thread create the
// socket and detach/reattach t1_'s SSL context 1000 times as a stress check,
// then schedule work on t2_'s thread where t2_.attach() binds the socket to
// t2_'s event base and context and connect(this, address_) is issued.
// NOTE(review): the enclosing member's signature, and any explicit detach
// from t1_ implied by the "detach from t1" comment, are not visible in this
// view.
90 // create in one and then move to another
91 auto t1Evb = t1_.getEventBase();
92 t1Evb->runInEventBaseThread([this] {
93 sslSocket_ = t1_.createSocket();
94 // ensure we can detach and reattach the context before connecting
95 for (int i = 0; i < 1000; ++i) {
96 sslSocket_->detachSSLContext();
97 sslSocket_->attachSSLContext(t1_.ctx_);
99 // detach from t1 and connect in t2
101 auto t2Evb = t2_.getEventBase();
102 t2Evb->runInEventBaseThread([this] {
103 t2_.attach(*sslSocket_);
104 sslSocket_->connect(this, address_);
// Connect succeeded. EXPECT_TRUE checks we are on t2_'s event-base thread.
// Cycles t2_'s SSL context 1000 times, then hops back to t1_ through two
// nested runInEventBaseThread calls. On t1_'s thread the socket is attached to
// t1_, buf_ is written, this object becomes the read callback, and readbuf_
// is filled with 'b' (presumably so leftover bytes cannot match buf_ by
// accident).
// NOTE(review): no detach before t1_.attach() and no initialization of buf_
// are visible in this view.
109 void connectSuccess() noexcept override {
110 auto t2Evb = t2_.getEventBase();
111 EXPECT_TRUE(t2Evb->isInEventBaseThread());
112 cerr << "client SSL socket connected" << endl;
113 for (int i = 0; i < 1000; ++i) {
114 sslSocket_->detachSSLContext();
115 sslSocket_->attachSSLContext(t2_.ctx_);
118 // detach from t2 and then read/write in t1
119 t2Evb->runInEventBaseThread([this] {
121 auto t1Evb = t1_.getEventBase();
122 t1Evb->runInEventBaseThread([this] {
123 t1_.attach(*sslSocket_);
124 sslSocket_->write(this, buf_, sizeof(buf_));
125 sslSocket_->setReadCB(this);
126 memset(readbuf_, 'b', sizeof(readbuf_));
// Connect failed: logs the error.
// NOTE(review): this view shows no promise_ completion on this path; if none
// exists, getFuture() callers only return via their timeout — confirm.
132 void connectErr(const AsyncSocketException& ex) noexcept override
134 cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
138 void writeSuccess() noexcept override {
139 cerr << "client write success" << endl;
142 void writeErr(size_t /* bytesWritten */,
143 const AsyncSocketException& ex) noexcept override {
144 cerr << "client writeError: " << ex.what() << endl;
147 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
148 *bufReturn = readbuf_ + bytesRead_;
149 *lenReturn = sizeof(readbuf_) - bytesRead_;
151 void readEOF() noexcept override {
152 cerr << "client readEOF" << endl;
155 void readErr(const AsyncSocketException& ex) noexcept override {
156 cerr << "client readError: " << ex.what() << endl;
157 promise_.setException(ex);
// Data arrived. Checks that we are on t1_'s event-base thread and that the
// socket is bound to t1_'s event base. When a read of exactly sizeof(buf_)
// bytes arrives, it compares readbuf_ with buf_, closes the socket
// immediately and fulfills promise_ with true.
// NOTE(review): completion is keyed on this single read's `len`, not on an
// accumulated count. If the data arrives split across several reads, the
// branch may never fire and the caller only returns via its timeout. No
// update of bytesRead_ is visible in this view either — confirm the memcmp
// length below.
160 void readDataAvailable(size_t len) noexcept override {
161 EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
162 EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
163 cerr << "client read data: " << len << endl;
165 if (len == sizeof(buf_)) {
166 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
167 sslSocket_->closeNow();
169 promise_.setValue(true);
175 * Test passing contexts between threads
// Starts a TestSSLServer whose callbacks are chained (write <- read <-
// handshake <- accept), then runs an AttachDetachClient against the server's
// address and expects its future to yield true within 3 seconds.
// NOTE(review): no call that starts the client (e.g. client->connect()) is
// visible in this view; without one the future can only time out.
177 TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
178 // Start listening on a local port
179 WriteCallbackBase writeCallback;
180 ReadCallback readCallback(&writeCallback);
181 HandshakeCallback handshakeCallback(&readCallback);
182 SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
183 TestSSLServer server(&acceptCallback);
185 std::shared_ptr<AttachDetachClient> client(
186 new AttachDetachClient(server.getAddress()));
188 auto f = client->getFuture();
190 EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
// ConnectClient: creates one AsyncSSLSocket on t1_'s event base, connects it,
// and resolves getFuture() with true on connect success or false on connect
// error.
193 class ConnectClient : public AsyncSocket::ConnectCallback {
195 ConnectClient() = default;
197 Future<bool> getFuture() {
198 return promise_.getFuture();
201 void connect(const folly::SocketAddress& addr) {
202 t1_.getEventBase()->runInEventBaseThread([&] {
203 socket_ = t1_.createSocket();
204 socket_->connect(this, addr);
// Connect succeeded: resolve the future with true.
// NOTE(review): lines following this one are not visible in this view.
208 void connectSuccess() noexcept override {
209 promise_.setValue(true);
// Connect failed: resolve the future with false. The exception itself is not
// forwarded, so callers only learn success vs. failure.
213 void connectErr(const AsyncSocketException& /* ex */) noexcept override {
214 promise_.setValue(false);
// Supplies the SSL context this client should use.
// NOTE(review): the body is not visible in this view; presumably it installs
// `ctx` as t1_'s context before connect() — confirm.
218 void setCtx(std::shared_ptr<SSLContext> ctx) {
224 // promise to fulfill when done with a value of true if connect succeeded
225 folly::Promise<bool> promise_;
// Socket created inside connect()'s scheduled lambda; held as a member so it
// stays alive after that lambda returns.
226 std::shared_ptr<AsyncSSLSocket> socket_;
// NoopReadCallback: server-side read callback (used under HandshakeCallback
// in the TLS 1.2 tests) that discards incoming data; its state is preset to
// STATE_SUCCEEDED in the constructor. The base receives nullptr, presumably
// meaning "no write callback".
229 class NoopReadCallback : public ReadCallbackBase {
231 NoopReadCallback() : ReadCallbackBase(nullptr) {
232 state = STATE_SUCCEEDED;
// NOTE(review): this method's body is not visible in this view.
235 void getReadBuffer(void** buf, size_t* lenReturn) override {
239 void readDataAvailable(size_t) noexcept override {}
244 TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
245 // Start listening on a local port
246 NoopReadCallback readCallback;
247 HandshakeCallback handshakeCallback(&readCallback);
248 SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
249 auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
250 TestSSLServer server(&acceptCallback, ctx);
251 server.loadTestCerts();
253 // create a default client
254 auto c1 = std::make_unique<ConnectClient>();
255 auto f1 = c1->getFuture();
256 c1->connect(server.getAddress());
257 EXPECT_TRUE(f1.within(std::chrono::seconds(3)).get());
260 TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
261 // Start listening on a local port
262 NoopReadCallback readCallback;
263 HandshakeCallback handshakeCallback(
264 &readCallback, HandshakeCallback::EXPECT_ERROR);
265 SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
266 auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
267 TestSSLServer server(&acceptCallback, ctx);
268 server.loadTestCerts();
270 // create a client that doesn't speak TLS 1.2
271 auto c2 = std::make_unique<ConnectClient>();
272 auto clientCtx = std::make_shared<SSLContext>();
273 clientCtx->setOptions(SSL_OP_NO_TLSv1_2);
274 c2->setCtx(clientCtx);
275 auto f2 = c2->getFuture();
276 c2->connect(server.getAddress());
277 EXPECT_FALSE(f2.within(std::chrono::seconds(3)).get());
// Test entry point: ignores SIGPIPE (so writing to a peer-closed socket does
// not terminate the process), initializes gtest and folly, then runs every
// registered test.
282 int main(int argc, char *argv[]) {
284 signal(SIGPIPE, SIG_IGN);
286 testing::InitGoogleTest(&argc, argv);
287 folly::init(&argc, &argv);
288 return RUN_ALL_TESTS();