Fixed -Wshadow warnings
[libcds.git] / test / include / cds_test / thread.h
index 4f813bd52b0defd2d4e3da889ec61715920dfe77..51fc399a200ca3157a06be83d2a96cbb681c60a4 100644 (file)
@@ -1,7 +1,7 @@
 /*
     This file is a part of libcds - Concurrent Data Structures library
 
-    (C) Copyright Maxim Khizhinsky (libcds.dev@gmail.com) 2006-2016
+    (C) Copyright Maxim Khizhinsky (libcds.dev@gmail.com) 2006-2017
 
     Source code repo: http://github.com/khizmax/libcds/
     Download: http://sourceforge.net/projects/libcds/files/
 #ifndef CDSTEST_THREAD_H
 #define CDSTEST_THREAD_H
 
-#include <gtest/gtest.h>
+#include <cds_test/ext_gtest.h>
 #include <vector>
 #include <thread>
 #include <condition_variable>
 #include <mutex>
 #include <chrono>
+#include <cds/threading/model.h>
 
 namespace cds_test {
 
@@ -45,7 +46,7 @@ namespace cds_test {
     class thread_pool;
 
     // Test thread
-    class thread 
+    class thread
     {
         void run();
 
@@ -55,93 +56,183 @@ namespace cds_test {
         virtual ~thread()
         {}
 
-        void join()         { m_impl.join(); }
-
     protected:
         virtual thread * clone() = 0;
         virtual void test() = 0;
 
         virtual void SetUp()
-        {}
+        {
+            cds::threading::Manager::attachThread();
+        }
+
         virtual void TearDown()
-        {}
+        {
+            cds::threading::Manager::detachThread();
+        }
 
     public:
         explicit thread( thread_pool& master, int type = 0 );
-        
+
         thread_pool& pool() { return m_pool; }
         int type() const { return m_type; }
         size_t id() const { return m_id;  }
+        bool time_elapsed() const;
 
     private:
         friend class thread_pool;
 
         thread_pool&    m_pool;
-        int             m_type;
-        size_t          m_id;
-        std::thread     m_impl;
+        int const       m_type;
+        size_t const    m_id;
     };
 
     // Pool of test threads
     class thread_pool
     {
+        class barrier
+        {
+        public:
+            barrier()
+                : m_count( 0 )
+            {}
+
+            void reset( size_t count )
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                m_count = count;
+            }
+
+            bool wait()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                if ( --m_count == 0 ) {
+                    m_cv.notify_all();
+                    return true;
+                }
+
+                while ( m_count != 0 )
+                    m_cv.wait( lock );
+
+                return false;
+            }
+
+        private:
+            size_t      m_count;
+            std::mutex  m_mtx;
+            std::condition_variable m_cv;
+        };
+
+        class initial_gate
+        {
+        public:
+            initial_gate()
+                : m_ready( false )
+            {}
+
+            void wait()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                while ( !m_ready )
+                    m_cv.wait( lock );
+            }
+
+            void ready()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                m_ready = true;
+                m_cv.notify_all();
+            }
+
+            void reset()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                m_ready = false;
+            }
+
+        private:
+            std::mutex  m_mtx;
+            std::condition_variable m_cv;
+            bool        m_ready;
+        };
+
     public:
-        explicit thread_pool( ::testing::Test& fixture )
-            : m_fixture( fixture )
-            , m_bRunning( false )
-            , m_bStopped( false )
-            , m_doneCount( 0 )
+        explicit thread_pool( ::testing::Test& fx )
+            : m_fixture( fx )
+            , m_bTimeElapsed( false )
         {}
 
         ~thread_pool()
         {
-            for ( auto t : m_threads )
-                delete t;
+            clear();
         }
 
-        void add( thread& what )
+        void add( thread * what )
         {
-            m_threads.push_back( &what );
-            what.run();
+            m_workers.push_back( what );
         }
 
-        void add( thread& what, size_t count )
+        void add( thread * what, size_t count )
         {
             add( what );
             for ( size_t i = 1; i < count; ++i ) {
-                thread * p = what.clone();
-                add( *p );
+                thread * p = what->clone();
+                add( p );
             }
         }
 
         std::chrono::milliseconds run()
         {
-            m_bStopped = false;
-            m_doneCount = 0;
+            return run( std::chrono::seconds::zero());
+        }
 
-            auto time_start = std::chrono::steady_clock::now();
+        std::chrono::milliseconds run( std::chrono::seconds duration )
+        {
+            m_startBarrier.reset( m_workers.size() + 1 );
+            m_stopBarrier.reset( m_workers.size() + 1 );
 
-            m_bRunning = true;
-            m_cvStart.notify_all();
+            // Create threads
+            std::vector< std::thread > threads;
+            threads.reserve( m_workers.size());
+            for ( auto w : m_workers )
+                threads.emplace_back( &thread::run, w );
 
-            {
-                scoped_lock l( m_cvMutex );
-                while ( m_doneCount != m_threads.size() )
-                    m_cvDone.wait( l );
-                m_bStopped = true;
+            // The pool is intialized
+            m_startPoint.ready();
+
+            m_bTimeElapsed.store( false, std::memory_order_release );
+
+            auto native_duration = std::chrono::duration_cast<std::chrono::steady_clock::duration>(duration);
+
+            // The pool is ready to start all workers
+            m_startBarrier.wait();
+
+            auto time_start = std::chrono::steady_clock::now();
+            auto const expected_end = time_start + native_duration;
+
+            if ( duration != std::chrono::seconds::zero()) {
+                for ( ;; ) {
+                    std::this_thread::sleep_for( native_duration );
+                    auto time_now = std::chrono::steady_clock::now();
+                    if ( time_now >= expected_end )
+                        break;
+                    native_duration = expected_end - time_now;
+                }
             }
-            auto time_end = std::chrono::steady_clock::now();
+            m_bTimeElapsed.store( true, std::memory_order_release );
+
+            // Waiting for all workers done
+            m_stopBarrier.wait();
 
-            m_cvStop.notify_all();
+            auto time_end = std::chrono::steady_clock::now();
 
-            for ( auto t : m_threads )
-                t->join();
+            for ( auto& t : threads )
+                t.join();
 
-            return m_testDuration = time_end - time_start;
+            return m_testDuration = std::chrono::duration_cast<std::chrono::milliseconds>(time_end - time_start);
         }
 
-        size_t size() const             { return m_threads.size(); }
-        thread& get( size_t idx ) const { return *m_threads.at( idx ); }
+        size_t size() const             { return m_workers.size(); }
+        thread& get( size_t idx ) const { return *m_workers.at( idx ); }
 
         template <typename Fixture>
         Fixture& fixture()
@@ -151,73 +242,66 @@ namespace cds_test {
 
         std::chrono::milliseconds duration() const { return m_testDuration; }
 
+        void clear()
+        {
+            for ( auto t : m_workers )
+                delete t;
+            m_workers.clear();
+            m_startPoint.reset();
+        }
+
+        void reset()
+        {
+            clear();
+        }
+
     protected: // thread interface
         size_t get_next_id()
         {
-            return m_threads.size();
+            return m_workers.size();
         }
 
-        void    ready_to_start( thread& /*who*/ )
+        void ready_to_start( thread& /*who*/ )
         {
             // Called from test thread
 
-            // Wait for all thread created
-            scoped_lock l( m_cvMutex );
-            while ( !m_bRunning )
-                m_cvStart.wait( l );
+            // Wait until the pool is ready
+            m_startPoint.wait();
+
+            // Wait until all thread ready
+            m_startBarrier.wait();
         }
 
-        void    thread_done( thread& /*who*/ )
+        void thread_done( thread& /*who*/ )
         {
             // Called from test thread
-
-            {
-                scoped_lock l( m_cvMutex );
-                ++m_doneCount;
-            }
-
-            // Tell pool that the thread is done
-            m_cvDone.notify_all();
-            
-            // Wait for all thread done
-            {
-                scoped_lock l( m_cvMutex );
-                while ( !m_bStopped )
-                    m_cvStop.wait( l );
-            }
+            m_stopBarrier.wait();
         }
 
     private:
         friend class thread;
 
         ::testing::Test&        m_fixture;
-        std::vector<thread *>   m_threads;
-
-        typedef std::unique_lock<std::mutex> scoped_lock;
-        std::mutex              m_cvMutex;
-        std::condition_variable m_cvStart;
-        std::condition_variable m_cvStop;
-        std::condition_variable m_cvDone;
+        std::vector<thread *>   m_workers;
 
-        volatile bool   m_bRunning;
-        volatile bool   m_bStopped;
-        volatile size_t m_doneCount;
+        initial_gate            m_startPoint;
+        barrier                 m_startBarrier;
+        barrier                 m_stopBarrier;
 
+        std::atomic<bool> m_bTimeElapsed;
         std::chrono::milliseconds m_testDuration;
     };
 
-    inline thread::thread( thread_pool& master, int type = 0 )
+    inline thread::thread( thread_pool& master, int type /*= 0*/ )
         : m_pool( master )
         , m_type( type )
         , m_id( master.get_next_id())
-        , m_impl( &run, this )
     {}
 
     inline thread::thread( thread const& sample )
         : m_pool( sample.m_pool )
         , m_type( sample.m_type )
-        , m_id( m_pool.get_next_id() )
-        , m_impl( &run, this )
+        , m_id( m_pool.get_next_id())
     {}
 
     inline void thread::run()
@@ -229,6 +313,11 @@ namespace cds_test {
         TearDown();
     }
 
+    inline bool thread::time_elapsed() const
+    {
+        return m_pool.m_bTimeElapsed.load( std::memory_order_acquire );
+    }
+
 } // namespace cds_test
 
 #endif // CDSTEST_THREAD_H