2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/SocketAddress.h>
19 #include <folly/experimental/TestUtil.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSocket.h>
23 #include <folly/io/async/AsyncTimeout.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/ssl/SSLErrors.h>
27 #include <folly/portability/GTest.h>
28 #include <folly/portability/Sockets.h>
29 #include <folly/portability/Unistd.h>
32 #include <sys/types.h>
// Certificate/key fixtures shared by the SSL tests. They are only declared
// here; the definitions live in another translation unit. Whether each string
// is a file path or inline PEM text is not visible from this header — TODO
// confirm against the definitions before relying on either.
//
// Server-side material (by name: certificate, private key, trusted CA).
extern const char* kTestCert;
extern const char* kTestKey;
extern const char* kTestCA;

// Client-side material (by name: certificate, private key, trusted CA),
// presumably for tests that exercise client authentication.
extern const char* kClientTestCert;
extern const char* kClientTestKey;
extern const char* kClientTestCA;
// Progress of a test participant: starts at STATE_WAITING and moves to
// STATE_SUCCEEDED or STATE_FAILED. Deliberately an unscoped enum: existing
// code names the enumerators unqualified (e.g. `state(STATE_WAITING)` and
// `EXPECT_EQ(STATE_SUCCEEDED, state)`), which `enum class` would break.
enum StateEnum { STATE_WAITING, STATE_SUCCEEDED, STATE_FAILED };
// Forward declaration only: the accept callback below stores a
// `HandshakeCallback*`, so the complete type is not required in this header.
class HandshakeCallback;
// NOTE(review): this copy of the class is incomplete. The embedded line
// numbers skip values, and names used below (`fd`, `base_`, `state`) as well
// as the `try` matching the `catch` are not present in these lines. Restore
// the missing lines from the upstream file before building.
//
// Base AcceptCallback for the SSL test server. For each accepted connection
// it wraps the fd in an AsyncSSLSocket (using ctx_ and base_) and hands the
// socket to the subclass hook connAccepted().
49 class SSLServerAcceptCallbackBase : public AsyncServerSocket::AcceptCallback {
// Starts in STATE_WAITING and keeps the raw HandshakeCallback pointer in hcb_.
51 explicit SSLServerAcceptCallbackBase(HandshakeCallback* hcb)
52 : state(STATE_WAITING), hcb_(hcb) {}
// Adds a gtest expectation failure if `state` is not STATE_SUCCEEDED by the
// time the callback is destroyed.
54 ~SSLServerAcceptCallbackBase() override {
55 EXPECT_EQ(STATE_SUCCEEDED, state);
// Logs the accept error at WARNING level (any further body lines are not
// present in this copy).
58 void acceptError(const std::exception& ex) noexcept override {
59 LOG(WARNING) << "SSLServerAcceptCallbackBase::acceptError " << ex.what();
// Invoked for each accepted connection. The accepted fd (used below as `fd`)
// is a parameter whose declaration line is missing from this copy.
63 void connectionAccepted(
65 const SocketAddress& /* clientAddr */) noexcept override {
// Detaches the previously created socket from its event base.
// NOTE(review): socket_ is dereferenced unconditionally in the visible lines;
// on the first connection it may still be null unless a guard (not shown
// here) precedes this call — confirm.
67 socket_->detachEventBase();
69 LOG(INFO) << "Connection accepted";
71 // Create a AsyncSSLSocket object with the fd. The socket should be
72 // added to the event base and in the state of accepting SSL connection.
73 socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
74 } catch (const std::exception& e) {
// NOTE(review): this is a stream log, so the "%s" below is printed literally
// rather than substituted; the remainder of the statement is missing here.
75 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
// Hand the newly created server-side socket to the subclass.
83 connAccepted(socket_);
// Subclass hook: receives each AsyncSSLSocket created in connectionAccepted().
86 virtual void connAccepted(const std::shared_ptr<AsyncSSLSocket>& s) = 0;
// Body line of a member whose signature is missing from this copy; detaches
// socket_ from its event base.
89 socket_->detachEventBase();
// Handshake callback passed to the constructor (stored as a raw pointer).
93 HandshakeCallback* hcb_;
// SSL context used to wrap each accepted fd.
94 std::shared_ptr<SSLContext> ctx_;
// Most recently created server-side socket (reassigned on every accept).
95 std::shared_ptr<AsyncSSLSocket> socket_;
// NOTE(review): TestSSLServer's class header, access specifiers, and several
// member lines (including the bodies of getEventBase() and getAddress()) are
// missing from this copy; restore them from the upstream file.
101 // Create a TestSSLServer.
102 // This immediately starts listening on the given port.
// NOTE(review): neither constructor below takes a port argument, so "the
// given port" looks stale; the bound address is exposed via getAddress() —
// confirm how the port is chosen. `enableTFO` presumably toggles TCP Fast
// Open — TODO confirm against the definition.
103 explicit TestSSLServer(
104 SSLServerAcceptCallbackBase* acb,
105 bool enableTFO = false);
// Overload that additionally takes an SSLContext, presumably used instead of
// a default-constructed one — TODO confirm against the definition.
106 explicit TestSSLServer(
107 SSLServerAcceptCallbackBase* acb,
108 std::shared_ptr<SSLContext> ctx,
109 bool enableTFO = false);
// Virtual so that derived test servers are destroyed correctly through a
// TestSSLServer pointer.
112 virtual ~TestSSLServer();
// Returns the server's EventBase by reference (body not present in this copy).
114 EventBase& getEventBase() {
// Declared only; by its name presumably loads the test certificate/key into
// the context — TODO confirm against the definition.
118 void loadTestCerts();
// Returns the server's address by const reference (body not present in this
// copy; presumably address_).
120 const SocketAddress& getAddress() const {
// SSL context for the server (shared ownership).
126 std::shared_ptr<SSLContext> ctx_;
// Accept callback supplied to the constructor (raw pointer).
127 SSLServerAcceptCallbackBase* acb_;
// Server socket (shared ownership).
128 std::shared_ptr<AsyncServerSocket> socket_;
// Address reported by getAddress(), presumably.
129 SocketAddress address_;