schedule: return next thread pointer directly
[model-checker.git] / libthreads.c
1 #include <string.h>
2 #include <stdlib.h>
3
4 #include "libthreads.h"
5 #include "schedule.h"
6 #include "common.h"
7
8 #define STACK_SIZE (1024 * 1024)
9
10 static struct thread *current, *main_thread;
11
12 static void *stack_allocate(size_t size)
13 {
14         return malloc(size);
15 }
16
17 static void stack_free(void *stack)
18 {
19         free(stack);
20 }
21
22 static int create_context(struct thread *t)
23 {
24         int ret;
25
26         memset(&t->context, 0, sizeof(t->context));
27         ret = getcontext(&t->context);
28         if (ret)
29                 return ret;
30
31         /* t->start_routine == NULL means this is our initial context */
32         if (!t->start_routine)
33                 return 0;
34
35         /* Initialize new managed context */
36         t->stack = stack_allocate(STACK_SIZE);
37         t->context.uc_stack.ss_sp = t->stack;
38         t->context.uc_stack.ss_size = STACK_SIZE;
39         t->context.uc_stack.ss_flags = 0;
40         t->context.uc_link = &main_thread->context;
41         makecontext(&t->context, t->start_routine, 1, t->arg);
42
43         return 0;
44 }
45
46 static int create_initial_thread(struct thread *t)
47 {
48         memset(t, 0, sizeof(*t));
49         return create_context(t);
50 }
51
52 static int thread_swap(struct thread *old, struct thread *new)
53 {
54         return swapcontext(&old->context, &new->context);
55 }
56
57 static int thread_yield()
58 {
59         struct thread *old, *next;
60
61         DBG();
62         old = current;
63         schedule_add_thread(old);
64         next = schedule_choose_next();
65         current = next;
66         DEBUG("(%d, %d)\n", old->index, next->index);
67         return thread_swap(old, next);
68 }
69
70 static void thread_dispose(struct thread *t)
71 {
72         DEBUG("completed thread %d\n", thread_current()->index);
73         t->completed = 1;
74         stack_free(t->stack);
75 }
76
77 static void thread_wait_finish()
78 {
79         struct thread *next;
80
81         DBG();
82
83         do {
84                 if (current)
85                         thread_dispose(current);
86                 next = schedule_choose_next();
87                 current = next;
88         } while (next && !thread_swap(main_thread, next));
89 }
90
91 int thread_create(struct thread *t, void (*start_routine), void *arg)
92 {
93         static int created = 1;
94         int ret = 0;
95
96         DBG();
97
98         memset(t, 0, sizeof(*t));
99         t->index = created++;
100         DEBUG("create thread %d\n", t->index);
101
102         t->start_routine = start_routine;
103         t->arg = arg;
104
105         /* Initialize state */
106         ret = create_context(t);
107         if (ret)
108                 return ret;
109
110         schedule_add_thread(t);
111         return 0;
112 }
113
114 void thread_join(struct thread *t)
115 {
116         while (!t->completed)
117                 thread_yield();
118 }
119
120 struct thread *thread_current(void)
121 {
122         return current;
123 }
124
125 void a(int *parm)
126 {
127         int i;
128
129         for (i = 0; i < 10; i++) {
130                 printf("Thread %d, magic number %d, loop %d\n", thread_current()->index, *parm, i);
131                 if (i % 2)
132                         thread_yield();
133         }
134 }
135
136 void user_main()
137 {
138         struct thread t1, t2;
139         int i = 17, j = 13;
140
141         printf("%s() creating 2 threads\n", __func__);
142         thread_create(&t1, &a, &i);
143         thread_create(&t2, &a, &j);
144
145         thread_join(&t1);
146         thread_join(&t2);
147         printf("%s() is finished\n", __func__);
148 }
149
150 int main()
151 {
152         struct thread user_thread;
153
154         main_thread = malloc(sizeof(struct thread));
155         create_initial_thread(main_thread);
156
157         /* Start user program */
158         thread_create(&user_thread, &user_main, NULL);
159
160         /* Wait for all threads to complete */
161         thread_wait_finish();
162
163         DEBUG("Exiting\n");
164         return 0;
165 }