common.h: move common code (non-user) to header
[model-checker.git] / libthreads.c
index 4545ee3640a79a542e1c61f199bd37853c4c9231..3eae58fca24da4556c9dfb993faba81105e91f0b 100644 (file)
 #include <string.h>
 #include <stdlib.h>
-#include <ucontext.h>
-#include <stdio.h>
 
-//#define CONFIG_DEBUG
-
-#ifdef CONFIG_DEBUG
-#define DBG() do { printf("Here: %s, L%d\n", __func__, __LINE__); } while (0)
-#define DEBUG(fmt, ...) printf(fmt, ##__VA_ARGS__)
-#else
-#define DBG()
-#define DEBUG(fmt, ...)
-#endif
+#include "libthreads.h"
+#include "common.h"
 
 #define STACK_SIZE (1024 * 1024)
 
-struct thread {
-       void (*start_routine);
-       void *arg;
-       ucontext_t context;
-       void *stack;
-       int index;
-};
-
-static struct thread *current;
-static ucontext_t *cleanup;
+static struct thread *current, *main_thread;
 
 static void *stack_allocate(size_t size)
 {
        return malloc(size);
 }
 
-int thread_create(struct thread *t, void (*start_routine), void *arg)
+static void stack_free(void *stack)
 {
-       static int created;
-       ucontext_t local;
+       free(stack);
+}
 
-       DBG();
+static int create_context(struct thread *t)
+{
+       int ret;
 
-       t->index = created++;
-       DEBUG("create thread %d\n", t->index);
+       memset(&t->context, 0, sizeof(t->context));
+       ret = getcontext(&t->context);
+       if (ret)
+               return ret;
 
-       t->start_routine = start_routine;
-       t->arg = arg;
+       /* t->start_routine == NULL means this is our initial context */
+       if (!t->start_routine)
+               return 0;
 
-       /* Initialize state */
-       getcontext(&t->context);
+       /* Initialize new managed context */
        t->stack = stack_allocate(STACK_SIZE);
        t->context.uc_stack.ss_sp = t->stack;
        t->context.uc_stack.ss_size = STACK_SIZE;
        t->context.uc_stack.ss_flags = 0;
-       if (current)
-               t->context.uc_link = &current->context;
-       else
-               t->context.uc_link = cleanup;
+       t->context.uc_link = &main_thread->context;
        makecontext(&t->context, t->start_routine, 1, t->arg);
 
        return 0;
 }
 
-void thread_start(struct thread *t)
+static int create_initial_thread(struct thread *t)
+{
+       memset(t, 0, sizeof(*t));
+       return create_context(t);
+}
+
+static int thread_swap(struct thread *old, struct thread *new)
+{
+       return swapcontext(&old->context, &new->context);
+}
+
+static int thread_yield()
 {
+       struct thread *old, *next;
+
        DBG();
+       old = current;
+       schedule_add_thread(old);
+       schedule_choose_next(&next);
+       current = next;
+       DEBUG("(%d, %d)\n", old->index, next->index);
+       return thread_swap(old, next);
+}
+
+static void thread_dispose(struct thread *t)
+{
+       DEBUG("completed thread %d\n", thread_current()->index);
+       t->completed = 1;
+       stack_free(t->stack);
+}
+
+static void thread_wait_finish()
+{
+       struct thread *next;
 
-       if (current) {
-               struct thread *old = current;
-               current = t;
-               swapcontext(&old->context, &current->context);
-       } else {
-               current = t;
-               swapcontext(cleanup, &current->context);
-       }
        DBG();
+
+       do {
+               if (current)
+                       thread_dispose(current);
+               schedule_choose_next(&next);
+               current = next;
+       } while (next && !thread_swap(main_thread, next));
 }
 
-void a(int *idx)
+int thread_create(struct thread *t, void (*start_routine), void *arg)
+{
+       static int created = 1;
+       int ret = 0;
+
+       DBG();
+
+       memset(t, 0, sizeof(*t));
+       t->index = created++;
+       DEBUG("create thread %d\n", t->index);
+
+       t->start_routine = start_routine;
+       t->arg = arg;
+
+       /* Initialize state */
+       ret = create_context(t);
+       if (ret)
+               return ret;
+
+       schedule_add_thread(t);
+       return 0;
+}
+
+void thread_join(struct thread *t)
+{
+       while (!t->completed)
+               thread_yield();
+}
+
+struct thread *thread_current(void)
+{
+       return current;
+}
+
+void a(int *parm)
 {
        int i;
 
-       for (i = 0; i < 10; i++)
-               printf("Thread %d, loop %d\n", *idx, i);
+       for (i = 0; i < 10; i++) {
+               printf("Thread %d, magic number %d, loop %d\n", thread_current()->index, *parm, i);
+               if (i % 2)
+                       thread_yield();
+       }
 }
 
 void user_main()
 {
        struct thread t1, t2;
-       int i = 1, j = 2;
+       int i = 17, j = 13;
 
+       printf("%s() creating 2 threads\n", __func__);
        thread_create(&t1, &a, &i);
        thread_create(&t2, &a, &j);
 
-       printf("user_main() is going to start 2 threads\n");
-       thread_start(&t1);
-       thread_start(&t2);
-       printf("user_main() is finished\n");
+       thread_join(&t1);
+       thread_join(&t2);
+       printf("%s() is finished\n", __func__);
 }
 
 int main()
 {
-       struct thread t;
-       ucontext_t main_context;
-
-       cleanup = &main_context;
+       struct thread user_thread;
 
-       thread_create(&t, &user_main, NULL);
+       main_thread = malloc(sizeof(struct thread));
+       create_initial_thread(main_thread);
 
-       thread_start(&t);
+       /* Start user program */
+       thread_create(&user_thread, &user_main, NULL);
 
-       DBG();
+       /* Wait for all threads to complete */
+       thread_wait_finish();
 
        DEBUG("Exiting\n");
        return 0;