2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include <folly/io/ShutdownSocketSet.h>
22 #include <glog/logging.h>
24 #include <folly/portability/GTest.h>
25 #include <folly/portability/Sockets.h>
27 using folly::ShutdownSocketSet;
29 namespace fsp = folly::portability::sockets;
34 ShutdownSocketSet shutdownSocketSet;
40 void stop(bool abortive);
42 int port() const { return port_; }
43 int closeClients(bool abortive);
53 std::atomic<StopMode> stop_;
54 std::thread serverThread_;
55 std::vector<int> fds_;
62 acceptSocket_ = fsp::socket(PF_INET, SOCK_STREAM, 0);
63 CHECK_ERR(acceptSocket_);
64 shutdownSocketSet.add(acceptSocket_);
67 addr.sin_family = AF_INET;
69 addr.sin_addr.s_addr = INADDR_ANY;
70 CHECK_ERR(bind(acceptSocket_,
71 reinterpret_cast<const sockaddr*>(&addr),
74 CHECK_ERR(listen(acceptSocket_, 10));
76 socklen_t addrLen = sizeof(addr);
77 CHECK_ERR(getsockname(acceptSocket_,
78 reinterpret_cast<sockaddr*>(&addr),
81 port_ = ntohs(addr.sin_port);
83 serverThread_ = std::thread([this] {
84 while (stop_ == NO_STOP) {
86 socklen_t peerLen = sizeof(peer);
87 int fd = accept(acceptSocket_,
88 reinterpret_cast<sockaddr*>(&peer),
94 if (errno == EINVAL || errno == ENOTSOCK) { // socket broken
99 shutdownSocketSet.add(fd);
103 if (stop_ != NO_STOP) {
104 closeClients(stop_ == ABORTIVE);
107 shutdownSocketSet.close(acceptSocket_);
113 int Server::closeClients(bool abortive) {
114 for (int fd : fds_) {
116 struct linger l = {1, 0};
117 CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
119 shutdownSocketSet.close(fd);
126 void Server::stop(bool abortive) {
127 stop_ = abortive ? ABORTIVE : ORDERLY;
128 shutdown(acceptSocket_, SHUT_RDWR);
131 void Server::join() {
132 serverThread_.join();
135 int createConnectedSocket(int port) {
136 int sock = fsp::socket(PF_INET, SOCK_STREAM, 0);
139 addr.sin_family = AF_INET;
140 addr.sin_port = htons(port);
141 addr.sin_addr.s_addr = htonl((127 << 24) | 1); // XXX
142 CHECK_ERR(connect(sock,
143 reinterpret_cast<const sockaddr*>(&addr),
148 void runCloseTest(bool abortive) {
151 int sock = createConnectedSocket(server.port());
153 std::thread stopper([&server, abortive] {
154 std::this_thread::sleep_for(std::chrono::milliseconds(200));
155 server.stop(abortive);
160 int r = read(sock, &c, 1);
164 EXPECT_EQ(ECONNRESET, e);
173 EXPECT_EQ(0, server.closeClients(false)); // closed by server when it exited
176 TEST(ShutdownSocketSetTest, OrderlyClose) {
180 TEST(ShutdownSocketSetTest, AbortiveClose) {
184 void runKillTest(bool abortive) {
187 int sock = createConnectedSocket(server.port());
189 std::thread killer([&server, abortive] {
190 std::this_thread::sleep_for(std::chrono::milliseconds(200));
191 shutdownSocketSet.shutdownAll(abortive);
196 int r = read(sock, &c, 1);
198 // "abortive" is just a hint for ShutdownSocketSet, so accept both
202 EXPECT_EQ(ECONNRESET, errno);
214 // NOT closed by server when it exited
215 EXPECT_EQ(1, server.closeClients(false));
218 TEST(ShutdownSocketSetTest, OrderlyKill) {
222 TEST(ShutdownSocketSetTest, AbortiveKill) {