2 * Copyright 2014 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 // @author: Xin Liu <xliux@fb.com>
23 #include <system_error>
25 #include <glog/logging.h>
26 #include <gflags/gflags.h>
27 #include <folly/ConcurrentSkipList.h>
28 #include <folly/Foreach.h>
29 #include <folly/String.h>
30 #include <gtest/gtest.h>
// Command-line flag: number of worker threads started by testConcurrentAccess.
32 DEFINE_int32(num_threads, 12, "num concurrent threads to test");
// NOTE(review): file-scope using-directive; acceptable in a test .cpp but
// pulls every folly name into scope.
36 using namespace folly;
// Element type stored in the skip lists under test, and the folly types
// instantiated on it.
39 typedef int ValueType;
40 typedef detail::SkipListNode<ValueType> SkipListNodeType;
41 typedef ConcurrentSkipList<ValueType> SkipListType;
42 typedef SkipListType::Accessor SkipListAccessor;
// NOTE(review): `vector` is unqualified; it relies on a using-declaration for
// std::vector that is not present in this listing — confirm.
43 typedef vector<ValueType> VectorType;
// Ordered reference container used to check the skip list's contents.
44 typedef std::set<ValueType> SetType;
// Height argument passed to SkipListType::create() by every test.
46 static const int kHeadHeight = 2;
// Default for the `maxValue` parameter of the random helpers, which draw
// values with rand() % maxValue.
47 static const int kMaxValue = 5000;
// Test helper: runs `size` iterations, each drawing a pseudo-random value r
// in [0, maxValue) with rand().
// NOTE(review): lines are missing from this listing. Callers pass a third
// argument (a SetType* verifier, e.g. &verifier), and the loop body presumably
// records r in that verifier and adds it to skipList. Confirm against the
// full source.
49 static void randomAdding(int size,
50 SkipListAccessor skipList,
52 int maxValue = kMaxValue) {
53 for (int i = 0; i < size; ++i) {
// NOTE(review): rand() is not required to be thread-safe. This helper is
// launched on several std::threads by the concurrent tests.
54 int32_t r = rand() % maxValue;
// Test helper: runs `size` iterations, each drawing a pseudo-random value r
// in [0, maxValue) with rand().
// NOTE(review): lines are missing from this listing. Callers pass a third
// argument (a SetType* verifier). The loop body presumably removes r from
// skipList and records successful removals in the verifier;
// testConcurrentRemoval treats the union of the verifiers as the set of
// removed values.
60 static void randomRemoval(int size,
61 SkipListAccessor skipList,
63 int maxValue=kMaxValue) {
64 for (int i = 0; i < size; ++i) {
65 int32_t r = rand() % maxValue;
// Thread body: walks every element of skipList with FOR_EACH and reports the
// result through the out-parameter `sum`. The loop presumably accumulates
// each element into *sum.
// NOTE(review): the initialisation of *sum, the loop body and the closing
// braces are not present in this listing.
71 static void sumAllValues(SkipListAccessor skipList, int64_t *sum) {
73 FOR_EACH(it, skipList) {
// Log the value, not the pointer. `sum` is an int64_t*, so streaming it
// directly would print an address.
76 VLOG(20) << "sum = " << *sum;
// Thread body: walks `values` with a single Skipper and adds up the values
// that skipper.to() reports as found. testConcurrentAccess sorts each value
// list before passing it here, so the Skipper only moves forward.
// NOTE(review): the declaration/initialisation of the local `sum` and the
// closing braces are missing from this listing.
79 static void concurrentSkip(const vector<ValueType> *values,
80 SkipListAccessor skipList) {
82 SkipListAccessor::Skipper skipper(skipList);
83 FOR_EACH(it, *values) {
84 if (skipper.to(*it)) sum += *it;
// The total is only logged (verbose level 20); it is not handed back to the
// caller.
86 VLOG(20) << "sum = " << sum;
// Checks that skipList holds exactly the elements of `verifier`. It checks:
// - the sizes are equal;
// - every verifier element is found via contains() and find();
// - an element-wise std::equal matches the skip list's iteration order.
// Mixes fatal glog CHECKs (abort on failure) with non-fatal gtest EXPECTs.
// NOTE(review): the return statement and closing braces are not present in
// this listing. The callers visible here ignore the bool result.
89 bool verifyEqual(SkipListAccessor skipList,
90 const SetType &verifier) {
91 EXPECT_EQ(verifier.size(), skipList.size());
92 FOR_EACH(it, verifier) {
93 CHECK(skipList.contains(*it)) << *it;
94 SkipListType::const_iterator iter = skipList.find(*it);
95 CHECK(iter != skipList.end());
96 EXPECT_EQ(*iter, *it);
// The fatal CHECKs above guarantee every verifier element is in skipList.
// So skipList is at least as long as verifier, and std::equal cannot walk
// past skipList's end even if the (non-fatal) size EXPECT failed.
98 EXPECT_TRUE(std::equal(verifier.begin(), verifier.end(), skipList.begin()));
// Single-threaded walkthrough of the Accessor API: first()/last(),
// contains(), find(), insert() results, pop_back(), lower_bound() and
// Skipper.
// NOTE(review): many lines of this test are missing from this listing,
// including the add/remove calls the assertions depend on and the closing
// braces.
102 TEST(ConcurrentSkipList, SequentialAccess) {
104 LOG(INFO) << "nodetype size=" << sizeof(SkipListNodeType);
106 auto skipList(SkipListType::create(kHeadHeight));
// A freshly created list has no first or last element.
107 EXPECT_TRUE(skipList.first() == nullptr);
108 EXPECT_TRUE(skipList.last() == nullptr);
// NOTE(review): the insertion of 3 these checks rely on is not in this listing.
111 EXPECT_TRUE(skipList.contains(3));
112 EXPECT_FALSE(skipList.contains(2));
113 EXPECT_EQ(3, *skipList.first());
114 EXPECT_EQ(3, *skipList.last());
116 EXPECT_EQ(3, *skipList.find(3));
117 EXPECT_FALSE(skipList.find(3) == skipList.end());
118 EXPECT_TRUE(skipList.find(2) == skipList.end());
121 SkipListAccessor::Skipper skipper(skipList);
123 CHECK_EQ(3, *skipper);
127 EXPECT_EQ(2, *skipList.first());
128 EXPECT_EQ(3, *skipList.last());
130 EXPECT_EQ(5, *skipList.last());
132 EXPECT_EQ(5, *skipList.last());
// insert() returns a pair: an iterator to the element and whether it was
// newly inserted. Re-inserting an existing value (5) yields false.
133 auto ret = skipList.insert(9);
134 EXPECT_EQ(9, *ret.first);
135 EXPECT_TRUE(ret.second);
137 ret = skipList.insert(5);
138 EXPECT_EQ(5, *ret.first);
139 EXPECT_FALSE(ret.second);
141 EXPECT_EQ(2, *skipList.first());
142 EXPECT_EQ(9, *skipList.last());
// pop_back() removes the current last element, so last() steps down
// 9 -> 5 -> 3.
143 EXPECT_TRUE(skipList.pop_back());
144 EXPECT_EQ(5, *skipList.last());
145 EXPECT_TRUE(skipList.pop_back());
146 EXPECT_EQ(3, *skipList.last());
// NOTE(review): 5 and 9 are expected to be present again here; the
// re-insertions are not in this listing.
151 CHECK(skipList.contains(2));
152 CHECK(skipList.contains(3));
153 CHECK(skipList.contains(5));
154 CHECK(skipList.contains(9));
155 CHECK(!skipList.contains(4));
158 auto it = skipList.lower_bound(5);
160 it = skipList.lower_bound(4);
162 it = skipList.lower_bound(9);
// 12 is above every value checked above, so lower_bound yields an iterator
// that is not good().
164 it = skipList.lower_bound(12);
165 EXPECT_FALSE(it.good());
167 it = skipList.begin();
171 SkipListAccessor::Skipper skipper(skipList);
173 EXPECT_EQ(3, skipper.data());
175 EXPECT_EQ(5, skipper.data());
// to() returns false when the target value (7) is expected to be absent.
176 CHECK(!skipper.to(7));
180 CHECK(skipper.to(9));
181 EXPECT_EQ(9, skipper.data());
// NOTE(review): the removal and re-insertion of 3 between these two checks
// are not in this listing.
183 CHECK(!skipList.contains(3));
185 CHECK(skipList.contains(3));
// Dump the final contents. `pos` is declared on a line missing from this
// listing.
187 FOR_EACH(it, skipList) {
188 LOG(INFO) << "pos= " << pos++ << " value= " << *it;
// Randomized check; this may still be inside SequentialAccess, since the
// enclosing lines are not in this listing. Steps:
// - fill a list with 10000 random values and compare it to the reference set;
// - probe 1000 evenly spaced keys in [0, kMaxValue) with one Skipper.
// NOTE(review): the declaration of `verifier` (passed as &verifier, so
// presumably a SetType) is missing from this listing.
193 auto skipList(SkipListType::create(kHeadHeight));
196 randomAdding(10000, skipList, &verifier);
197 verifyEqual(skipList, verifier);
200 SkipListAccessor::Skipper skipper(skipList);
201 int num_skips = 1000;
202 for (int i = 0; i < num_skips; ++i) {
// Keys increase with i, so the Skipper only ever moves forward.
203 int n = i * kMaxValue / num_skips;
// to(n) must report "found" exactly when the reference set contains n.
204 bool found = skipper.to(n);
205 EXPECT_EQ(found, (verifier.find(n) != verifier.end()));
// Builds a string of `len` pseudo-random uppercase letters 'A'..'Z' (the
// string is presumably returned).
// NOTE(review): the name is misspelled ("Randome"); a rename must also update
// its caller. The declaration of `s`, the return statement and the closing
// braces are missing from this listing.
211 static std::string makeRandomeString(int len) {
213 for (int j = 0; j < len; j++) {
214 s.push_back((rand() % 26) + 'A');
// Uses a std::string-keyed skip list: generates 100000 random 7-character
// strings and checks that iterating the list yields them in sorted order.
// NOTE(review): the insert call inside the loop is missing from this listing.
219 TEST(ConcurrentSkipList, TestStringType) {
220 typedef folly::ConcurrentSkipList<std::string> SkipListT;
221 std::shared_ptr<SkipListT> skip = SkipListT::createInstance();
222 SkipListT::Accessor accessor(skip);
224 for (int i = 0; i < 100000; i++) {
225 std::string s = makeRandomeString(7);
229 EXPECT_TRUE(std::is_sorted(accessor.begin(), accessor.end()));
// Comparator for ConcurrentSkipList<std::unique_ptr<int>>. It must order by
// pointee value: TestMovableData looks elements up with freshly allocated
// keys, which could never match by address. A null `x` is never ordered
// before anything.
// NOTE(review): the operator() name line and the comparison after the null
// check are missing from this listing.
232 struct UniquePtrComp {
234 const std::unique_ptr<int> &x, const std::unique_ptr<int> &y) const {
235 if (!x) return false;
// Verifies the skip list works with a move-only element type:
// - inserts unique_ptr<int> values 0..N-1 by move;
// - looks each one up with a freshly allocated key (relies on UniquePtrComp
//   comparing pointees).
241 TEST(ConcurrentSkipList, TestMovableData) {
242 typedef folly::ConcurrentSkipList<std::unique_ptr<int>, UniquePtrComp>
244 auto sl = SkipListT::createInstance() ;
245 SkipListT::Accessor accessor(sl);
247 static const int N = 10;
248 for (int i = 0; i < N; ++i) {
249 accessor.insert(std::unique_ptr<int>(new int(i)));
252 for (int i = 0; i < N; ++i) {
253 EXPECT_TRUE(accessor.find(std::unique_ptr<int>(new int(i))) !=
// N itself was never inserted, so this lookup must miss.
256 EXPECT_TRUE(accessor.find(std::unique_ptr<int>(new int(N))) ==
// Launches numThreads threads, each running randomAdding(100, ...) against the
// shared skipList with its own verifier set. Afterwards it unions the
// verifiers and checks the list's contents with verifyEqual.
// If thread creation throws std::system_error, the handler logs how many
// threads were created. Only those threads are then visited, and the unused
// verifiers stay empty, so the union remains correct.
// NOTE(review): the `try` keyword, the join call, the declaration of `all`
// and the closing braces are missing from this listing.
260 void testConcurrentAdd(int numThreads) {
261 auto skipList(SkipListType::create(kHeadHeight));
263 vector<std::thread> threads;
264 vector<SetType> verifiers(numThreads);
266 for (int i = 0; i < numThreads; ++i) {
267 threads.push_back(std::thread(
268 &randomAdding, 100, skipList, &verifiers[i], kMaxValue));
270 } catch (const std::system_error& e) {
272 << "Caught " << exceptionStr(e)
273 << ": could only create " << threads.size() << " threads out of "
// Index with size_t to match threads.size() and avoid a signed/unsigned
// comparison.
276 for (size_t i = 0; i < threads.size(); ++i) {
281 FOR_EACH(s, verifiers) {
282 all.insert(s->begin(), s->end());
284 verifyEqual(skipList, all);
// Runs testConcurrentAdd with increasing thread counts. Thread-creation
// failure at high counts is tolerated inside testConcurrentAdd.
287 TEST(ConcurrentSkipList, ConcurrentAdd) {
288 // Repeat with 10, 1010, ..., 9010 threads (step 1000).
289 for (int numThreads = 10; numThreads < 10000; numThreads += 1000) {
290 testConcurrentAdd(numThreads);
// Populates a list with the values 0..maxValue-1 (presumably; the add call is
// missing from this listing). Then numThreads threads each run
// randomRemoval(100, ...) with their own verifier. Afterwards the union of the
// verifiers must be exactly the set of values no longer in the list.
// NOTE(review): the `try` keyword, the join body, the declaration of `all`
// and the closing braces are missing from this listing.
294 void testConcurrentRemoval(int numThreads, int maxValue) {
295 auto skipList = SkipListType::create(kHeadHeight);
296 for (int i = 0; i < maxValue; ++i) {
300 vector<std::thread> threads;
301 vector<SetType > verifiers(numThreads);
303 for (int i = 0; i < numThreads; ++i) {
304 threads.push_back(std::thread(
305 &randomRemoval, 100, skipList, &verifiers[i], maxValue));
307 } catch (const std::system_error& e) {
309 << "Caught " << exceptionStr(e)
310 << ": could only create " << threads.size() << " threads out of "
313 FOR_EACH(t, threads) {
318 FOR_EACH(s, verifiers) {
319 all.insert(s->begin(), s->end());
// Every original value is either removed (recorded in a verifier) or still
// present. The cast makes CHECK_EQ compare two size_t values instead of int
// vs size_t inside glog's comparison template.
322 CHECK_EQ(static_cast<size_t>(maxValue), all.size() + skipList.size());
323 for (int i = 0; i < maxValue; ++i) {
324 if (all.find(i) != all.end()) {
325 CHECK(!skipList.contains(i)) << i;
327 CHECK(skipList.contains(i)) << i;
// Runs testConcurrentRemoval with 10, 110, ..., 910 threads. The list is
// sized at 100 values per thread.
332 TEST(ConcurrentSkipList, ConcurrentRemove) {
333 for (int numThreads = 10; numThreads < 1000; numThreads += 100) {
334 testConcurrentRemoval(numThreads, 100 * numThreads);
// Stress test. FLAGS_num_threads threads run a mix of randomAdding,
// randomRemoval, concurrentSkip and sumAllValues against one shared list.
// NOTE(review): the per-thread dispatch (which helper thread i runs) and the
// join body are on lines missing from this listing.
338 static void testConcurrentAccess(
339 int numInsertions, int numDeletions, int maxValue) {
340 auto skipList = SkipListType::create(kHeadHeight);
342 vector<SetType> verifiers(FLAGS_num_threads);
343 vector<int64_t> sums(FLAGS_num_threads);
344 vector<vector<ValueType> > skipValues(FLAGS_num_threads);
// Pre-generate one sorted key list per thread for concurrentSkip, with values
// in [0, maxValue] inclusive. Sorting lets the Skipper move strictly forward.
346 for (int i = 0; i < FLAGS_num_threads; ++i) {
347 for (int j = 0; j < numInsertions; ++j) {
348 skipValues[i].push_back(rand() % (maxValue + 1));
350 std::sort(skipValues[i].begin(), skipValues[i].end());
353 vector<std::thread> threads;
354 for (int i = 0; i < FLAGS_num_threads; ++i) {
358 threads.push_back(std::thread(
359 randomAdding, numInsertions, skipList, &verifiers[i], maxValue));
362 threads.push_back(std::thread(
363 randomRemoval, numDeletions, skipList, &verifiers[i], maxValue));
366 threads.push_back(std::thread(
367 concurrentSkip, &skipValues[i], skipList));
370 threads.push_back(std::thread(sumAllValues, skipList, &sums[i]));
375 FOR_EACH(t, threads) {
378 // Intentionally no correctness check: this test only exercises concurrent access.
// Three stress configurations, each given as
// (numInsertions, numDeletions, maxValue) to testConcurrentAccess.
381 TEST(ConcurrentSkipList, ConcurrentAccess) {
382 testConcurrentAccess(10000, 100, kMaxValue);
383 testConcurrentAccess(100000, 10000, kMaxValue * 10);
384 testConcurrentAccess(1000000, 100000, kMaxValue);
// Test entry point. gtest is initialised first, so it strips its own
// --gtest_* flags from argv. gflags then parses and removes the remaining
// flags, such as --num_threads.
389 int main(int argc, char* argv[]) {
390 testing::InitGoogleTest(&argc, argv);
391 google::InitGoogleLogging(argv[0]);
392 gflags::ParseCommandLineFlags(&argc, &argv, true);
394 return RUN_ALL_TESTS();