2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include <folly/Foreach.h>
17 #include <folly/io/Cursor.h>
18 #include <folly/io/async/AsyncSSLSocket.h>
19 #include <folly/io/async/AsyncSocket.h>
20 #include <folly/io/async/EventBase.h>
22 #include <gtest/gtest.h>
23 #include <gmock/gmock.h>
28 using namespace testing;
// Test double for AsyncSSLSocket. The SSL_write path (sslWriteImpl) and
// getRawBytesWritten() are replaced with gmock mocks, so tests can inspect the
// buffers that performWrite() passes down and can script the raw byte counter
// that the EOR-tracking tests depend on.
32 class MockAsyncSSLSocket : public AsyncSSLSocket{
// Factory: constructs the mock on `evb` and attaches a fresh SSL object created
// from `ctx`. The SSL object is given fd -1, so no real descriptor is involved.
34 static std::shared_ptr<MockAsyncSSLSocket> newSocket(
35 const std::shared_ptr<SSLContext>& ctx,
37 auto sock = std::shared_ptr<MockAsyncSSLSocket>(
38 new MockAsyncSSLSocket(ctx, evb),
40 sock->ssl_ = SSL_new(ctx->getSSLCtx());
41 SSL_set_fd(sock->ssl_, -1);
45 // Fake constructor sets the state to established without call to connect
// AsyncSocket is only an indirect base here. Naming it in the mem-initializer
// list is valid only because it is a virtual base, which the most-derived class
// must construct itself.
47 MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
49 : AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
// Mark both the socket-level state and the SSL state as established. The
// write path can then be exercised without a real connect/handshake.
50 state_ = AsyncSocket::StateEnum::ESTABLISHED;
51 sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
54 // mock the calls to SSL_write to see the buffer length and contents
55 MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
57 // mock the calls to getRawBytesWritten()
58 MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
60 // public wrapper for protected interface
// Forwards all arguments to the protected performWrite() unchanged and returns
// its result, so tests can drive a write directly.
61 ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
62 uint32_t* countWritten, uint32_t* partialWritten) {
63 return performWrite(vec, count, flags, countWritten, partialWritten);
// Asserts the socket's EOR bookkeeping:
//   appEor - expected value of appEorByteNo_
//   rawEor - expected value of minEorRawByteNo_
// Tests pass (0, 0) to check that no EOR position is currently recorded.
66 void checkEor(size_t appEor, size_t rawEor) {
67 EXPECT_EQ(appEor, appEorByteNo_);
68 EXPECT_EQ(rawEor, minEorRawByteNo_);
// Sets the starting application byte count. The EOR tests compute their
// expected EOR offsets relative to this value.
// NOTE(review): the body is not visible in this excerpt; presumably it assigns
// the app-bytes-written counter. Confirm.
71 void setAppBytesWritten(size_t n) {
// Fixture: owns an EventBase, an SSLContext and a MockAsyncSSLSocket. It also
// owns a reference buffer source_, which every iovec built by makeVec() points
// into, so written bytes can be checked against known content.
76 class AsyncSSLSocketWriteTest : public testing::Test {
78 AsyncSSLSocketWriteTest() :
79 sslContext_(new SSLContext()),
80 sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
// Fill source_ with 500 copies of "a..z" so every offset has predictable content.
81 for (int i = 0; i < 500; i++) {
82 memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
86 // Make an iovec containing chunks of the reference text with requested sizes
// Each entry points into source_ at the running offset `pos` (no data is
// copied). The returned array owns only the iovec structs, not the bytes.
88 std::unique_ptr<iovec[]> makeVec(std::vector<uint32_t> sizes) {
89 std::unique_ptr<iovec[]> vec(new iovec[sizes.size()]);
92 for (auto size: sizes) {
93 vec[i].iov_base = (void *)(source_ + pos);
94 vec[i++].iov_len = size;
100 // Verify that the given buf/pos matches the reference text
// NOTE: ASSERT_EQ inside a helper returns only from verifyVec, not from the
// calling test or mock callback. The failure is recorded, but the caller keeps
// running.
101 void verifyVec(const void *buf, int n, int pos) {
102 ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
105 // Update a vec on partial write
// Advances entry vec[countWritten] past the partialWritten bytes already sent.
// The remainder can then be passed to a follow-up write starting at
// vec + countWritten.
106 void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
107 vec[countWritten].iov_base =
108 ((char *)vec[countWritten].iov_base) + partialWritten;
109 vec[countWritten].iov_len -= partialWritten;
// Declaration order matters. eventBase_ is constructed before sock_, whose
// initializer takes &eventBase_. Members are destroyed in reverse order, so
// sock_ is released before eventBase_ goes away.
112 EventBase eventBase_;
113 std::shared_ptr<SSLContext> sslContext_;
114 std::shared_ptr<MockAsyncSSLSocket> sock_;
// 26 * 500 = 13000 bytes of reference text.
115 char source_[26 * 500];
119 // The entire vec fits in one packet
// Three 3-byte chunks are expected to be coalesced into a single 9-byte
// SSL_write whose contents match source_ from offset 0. The test then expects
// every iovec entry to be reported as fully written, with no partial remainder.
120 TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
122 auto vec = makeVec({3, 3, 3});
123 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
124 .WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
125 verifyVec(buf, n, 0);
127 uint32_t countWritten = 0;
128 uint32_t partialWritten = 0;
129 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
131 EXPECT_EQ(countWritten, n);
132 EXPECT_EQ(partialWritten, 0);
135 // First packet is full, second two go in one packet
// Chunks {1500, 3, 3}. Expect a 1500-byte SSL_write for the first chunk, then a
// single 6-byte SSL_write that coalesces the two 3-byte chunks.
// `pos` (captured by reference) is the expected offset into source_ for each
// write. Presumably each mock callback advances it by n; that line is not
// visible in this excerpt.
136 TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
138 auto vec = makeVec({1500, 3, 3});
140 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
141 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
142 verifyVec(buf, n, pos);
145 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
146 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
147 verifyVec(buf, n, pos);
150 uint32_t countWritten = 0;
151 uint32_t partialWritten = 0;
152 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
154 EXPECT_EQ(countWritten, n);
155 EXPECT_EQ(partialWritten, 0);
158 // Two exactly full packets (coalesce ends midway through second chunk)
// Chunks {1000, 1000, 1000} total 3000 bytes, split into two 1500-byte writes:
//   - write 1: chunk 0 plus the first 500 bytes of chunk 1;
//   - write 2: the remaining 500 bytes of chunk 1 plus chunk 2.
// That is why the 1500-byte expectation uses WillRepeatedly.
159 TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
161 auto vec = makeVec({1000, 1000, 1000});
163 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
165 .WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
166 verifyVec(buf, n, pos);
169 uint32_t countWritten = 0;
170 uint32_t partialWritten = 0;
171 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
173 EXPECT_EQ(countWritten, n);
174 EXPECT_EQ(partialWritten, 0);
177 // Partial write success midway through a coalesced vec
// Chunks {300 x 5}.
// The first 1500-byte SSL_write reports only 1000 bytes written. That covers
// chunks 0-2 (900 bytes) plus 100 bytes of chunk 3, hence countWritten == 3
// and partialWritten == 100.
// consumeVec() then trims chunk 3. The retry starting at vec + 3 must issue a
// single 500-byte write: the 200 remaining bytes of chunk 3 plus chunk 4. Both
// remaining entries count as written (countWritten == 2 relative to the new
// start).
178 TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
180 auto vec = makeVec({300, 300, 300, 300, 300});
182 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
183 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
184 verifyVec(buf, n, pos);
186 return 1000; /* 500 bytes "pending" */ }));
187 uint32_t countWritten = 0;
188 uint32_t partialWritten = 0;
189 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
191 EXPECT_EQ(countWritten, 3);
192 EXPECT_EQ(partialWritten, 100);
193 consumeVec(vec.get(), countWritten, partialWritten);
194 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
195 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
196 verifyVec(buf, n, pos);
199 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
201 &countWritten, &partialWritten);
202 EXPECT_EQ(countWritten, 2);
203 EXPECT_EQ(partialWritten, 0);
206 // Coalescing ends exactly on a buffer boundary
// Chunks {1000, 500, 500}. The first 1500-byte write covers chunks 0 and 1
// exactly. The last chunk then goes out as its own 500-byte write, and all
// three entries are expected to be reported as fully written.
207 TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
209 auto vec = makeVec({1000, 500, 500});
211 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
212 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
213 verifyVec(buf, n, pos);
216 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
217 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
218 verifyVec(buf, n, pos);
221 uint32_t countWritten = 0;
222 uint32_t partialWritten = 0;
223 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
225 EXPECT_EQ(countWritten, 3);
226 EXPECT_EQ(partialWritten, 0);
229 // Partial write midway through the first chunk
// Chunks {1000, 500}.
// The 1500-byte SSL_write is expected to end up as a partial write that stops
// 700 bytes into chunk 0 (countWritten == 0, partialWritten == 700).
// After consumeVec(), the retry must send the remaining 300 bytes of chunk 0
// plus chunk 1 as a single 800-byte write, completing both entries.
230 TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
232 auto vec = makeVec({1000, 500});
234 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
235 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
236 verifyVec(buf, n, pos);
239 uint32_t countWritten = 0;
240 uint32_t partialWritten = 0;
241 sock_->testPerformWrite(vec.get(), n, WriteFlags::NONE, &countWritten,
243 EXPECT_EQ(countWritten, 0);
244 EXPECT_EQ(partialWritten, 700);
245 consumeVec(vec.get(), countWritten, partialWritten);
246 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
247 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
248 verifyVec(buf, n, pos);
251 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
253 &countWritten, &partialWritten);
254 EXPECT_EQ(countWritten, 2);
255 EXPECT_EQ(partialWritten, 0);
258 // Repeat coalescing2 with WriteFlags::EOR
// Chunks {1500, 3, 3} are written with EOR tracking enabled. The app byte count
// starts at 500, so the EOR is expected at app offset 500 + 1506.
// Expected EOR state at each step:
//   - during the first 1500-byte SSL_write: none recorded yet;
//   - during the final 6-byte write: appEorByteNo_ == appEor and
//     minEorRawByteNo_ == 3600 + 6, where 3600 is the first mocked
//     getRawBytesWritten() value;
//   - after the whole write completes: cleared again (checkEor(0, 0)).
// NOTE(review): the `[=, &pos]` lambdas capture `this` implicitly through `=`.
// That is deprecated in C++20. `[this, &pos]` would suffice, as in the other
// tests, because appEor is a constant and needs no capture.
259 TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
261 auto vec = makeVec({1500, 3, 3});
263 const size_t initAppBytesWritten = 500;
264 const size_t appEor = initAppBytesWritten + 1506;
266 sock_->setAppBytesWritten(initAppBytesWritten);
267 EXPECT_FALSE(sock_->isEorTrackingEnabled());
268 sock_->setEorTracking(true);
269 EXPECT_TRUE(sock_->isEorTrackingEnabled());
271 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
272 // rawBytesWritten after writing initAppBytesWritten + 1500
273 // + some random SSL overhead
274 .WillOnce(Return(3600))
275 // rawBytesWritten after writing last 6 bytes
276 // + some random SSL overhead
277 .WillOnce(Return(3728));
278 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
279 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
280 // the first 1500 does not have the EOR byte
281 sock_->checkEor(0, 0);
282 verifyVec(buf, n, pos);
285 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
286 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
287 sock_->checkEor(appEor, 3600 + n);
288 verifyVec(buf, n, pos);
292 uint32_t countWritten = 0;
293 uint32_t partialWritten = 0;
294 sock_->testPerformWrite(vec.get(), n , WriteFlags::EOR,
295 &countWritten, &partialWritten);
296 EXPECT_EQ(countWritten, n);
297 EXPECT_EQ(partialWritten, 0);
298 sock_->checkEor(0, 0);
301 // coalescing with leftover at the last chunk
302 // WriteFlags::EOR turned on
// Chunks {600, 600, 600} are written with EOR tracking enabled, starting from
// an app byte count of 500.
// Expected EOR state at each step:
//   - during the first 1500-byte SSL_write (chunks 0 and 1 plus 300 bytes of
//     chunk 2): none recorded;
//   - during the trailing 300-byte write: appEorByteNo_ == 500 + 1800 and
//     minEorRawByteNo_ == 3600 + 300;
//   - after the whole vec is written: cleared.
// NOTE(review): `[=, &pos]` implicitly captures `this` (deprecated in C++20);
// `[this, &pos]` would suffice.
303 TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
305 auto vec = makeVec({600, 600, 600});
307 const size_t initAppBytesWritten = 500;
308 const size_t appEor = initAppBytesWritten + 1800;
310 sock_->setAppBytesWritten(initAppBytesWritten);
311 sock_->setEorTracking(true);
313 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
314 // rawBytesWritten after writing initAppBytesWritten + 1500 bytes
315 // + some random SSL overhead
316 .WillOnce(Return(3600))
317 // rawBytesWritten after writing last 300 bytes
318 // + some random SSL overhead
319 .WillOnce(Return(4100));
320 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
321 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
322 // the first 1500 does not have the EOR byte
323 sock_->checkEor(0, 0);
324 verifyVec(buf, n, pos);
327 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
328 .WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
329 sock_->checkEor(appEor, 3600 + n);
330 verifyVec(buf, n, pos);
334 uint32_t countWritten = 0;
335 uint32_t partialWritten = 0;
336 sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
337 &countWritten, &partialWritten);
338 EXPECT_EQ(countWritten, n);
339 EXPECT_EQ(partialWritten, 0);
340 sock_->checkEor(0, 0);
343 // WriteFlags::EOR set
345 // Partial write that stops after the first 1000 bytes
346 TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
348 auto vec = makeVec({1600});
350 const size_t initAppBytesWritten = 500;
351 const size_t appEor = initAppBytesWritten + 1600;
353 sock_->setAppBytesWritten(initAppBytesWritten);
354 sock_->setEorTracking(true);
356 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
357 // rawBytesWritten after the initAppBytesWritten
358 // + some random SSL overhead
359 .WillOnce(Return(2000))
360 // rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
361 // + some random SSL overhead
362 .WillOnce(Return(3100));
363 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
364 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
365 sock_->checkEor(appEor, 2000 + n);
366 verifyVec(buf, n, pos);
370 uint32_t countWritten = 0;
371 uint32_t partialWritten = 0;
372 sock_->testPerformWrite(vec.get(), n, WriteFlags::EOR,
373 &countWritten, &partialWritten);
374 EXPECT_EQ(countWritten, 0);
375 EXPECT_EQ(partialWritten, 1000);
376 sock_->checkEor(appEor, 2000 + 1600);
377 consumeVec(vec.get(), countWritten, partialWritten);
379 EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
380 .WillOnce(Return(3100))
381 .WillOnce(Return(3800));
382 EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
383 .WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
384 sock_->checkEor(appEor, 3100 + n);
385 verifyVec(buf, n, pos);
388 sock_->testPerformWrite(vec.get() + countWritten, n - countWritten,
390 &countWritten, &partialWritten);
391 EXPECT_EQ(countWritten, n);
392 EXPECT_EQ(partialWritten, 0);
393 sock_->checkEor(0, 0);