#include <cds_test/stress_test.h>
#include <iostream>
#include <iostream>
+#include <memory>
#include <thread>
using namespace std;
#define TASK(lock_type, lock_ptr, pass_cnt) \
static void Thread##lock_type() { \
- for (int i = 0; i < pass_cnt; i++) { \
- for (int j = 0; j < pass_cnt; j++) { \
- lock_ptr->lock(); \
- x = i + j; \
- lock_ptr->unlock(); \
- } \
+ for (size_t i = 0; i < pass_cnt; i++) { \
+ lock_ptr->lock(); \
+ x++; \
+ lock_ptr->unlock(); \
} \
}
-#define LOCK_TEST(lock_type, lock_ptr) \
+#define LOCK_TEST(lock_type, lock_ptr, pass_cnt) \
TEST_F(SpinLockTest, lock_type) { \
lock_ptr = new lock_type(); \
+ x = 0; \
+ std::unique_ptr<std::thread[]>(new std::thread[s_nSpinLockThreadCount]); \
std::thread *threads = new std::thread[s_nSpinLockThreadCount]; \
- for (int i = 0; i < s_nSpinLockThreadCount; i++) { \
+ for (size_t i = 0; i < s_nSpinLockThreadCount; i++) { \
threads[i] = std::thread(Thread##lock_type); \
} \
- for (int i = 0; i < s_nSpinLockThreadCount; i++) { \
+ for (size_t i = 0; i < s_nSpinLockThreadCount; i++) { \
threads[i].join(); \
} \
+ if (x != s_nSpinLockThreadCount * pass_cnt) { \
+ cout << "Incorrect " << #lock_type << "\n"; \
+ cout << "x=" << x << "\nThreadCount=" << s_nSpinLockThreadCount \
+ << "\nPassCount=" << pass_cnt \
+ << "\t&&\tSupposed times=" << s_nSpinLockThreadCount * pass_cnt \
+ << "\n"; \
+ } \
}
class SpinLockTest : public cds_test::stress_fixture {
protected:
- static int x;
+ static size_t x;
static TicketLock *ticket_mutex;
static SpinLock *spin_mutex;
static Reentrant32 *reentrant_mutex32;
TASK(Reentrant64, reentrant_mutex64, s_nSpinLockPassCount)
};
-int SpinLockTest::x;
+size_t SpinLockTest::x;
TicketLock *SpinLockTest::ticket_mutex;
SpinLock *SpinLockTest::spin_mutex;
Reentrant32 *SpinLockTest::reentrant_mutex32;
Reentrant64 *SpinLockTest::reentrant_mutex64;
-LOCK_TEST(TicketLock, ticket_mutex)
-LOCK_TEST(SpinLock, spin_mutex)
-LOCK_TEST(Reentrant32, reentrant_mutex32)
-LOCK_TEST(Reentrant64, reentrant_mutex64)
+LOCK_TEST(TicketLock, ticket_mutex, s_nTicketLockPassCount)
+LOCK_TEST(SpinLock, spin_mutex, s_nSpinLockPassCount)
+LOCK_TEST(Reentrant32, reentrant_mutex32, s_nSpinLockPassCount)
+LOCK_TEST(Reentrant64, reentrant_mutex64, s_nSpinLockPassCount)
} // namespace