/** Hashtable class. By default it is snapshotting, but you can pass
in your own allocation functions. */
-template<typename _Key, typename _Val, typename _KeyInt, int _Shift, void * (* _malloc)(size_t)=malloc, void * (* _calloc)(size_t, size_t)=calloc, void (*_free)(void *)=free>
+template<typename _Key, typename _Val, typename _KeyInt, int _Shift=0, void * (* _malloc)(size_t)=malloc, void * (* _calloc)(size_t, size_t)=calloc, void (*_free)(void *)=free>
class HashTable {
public:
HashTable(unsigned int initialcapacity=1024, double factor=0.5) {
table[(((_KeyInt)key)&mask)>>_Shift]=newptr;
}
+ /** Put a key entry into the table. */
+ _Val * ensureptr(_Key key) {
+ if(size > threshold) {
+ //Resize
+ unsigned int newsize = capacity << 1;
+ resize(newsize);
+ }
+
+ struct hashlistnode<_Key,_Val> *ptr = table[(((_KeyInt)key) & mask)>>_Shift];
+ size++;
+ struct hashlistnode<_Key,_Val> *search = ptr;
+
+ while(search!=NULL) {
+ if (search->key==key) {
+ return &search->val;
+ }
+ search=search->next;
+ }
+
+ struct hashlistnode<_Key,_Val> *newptr=(struct hashlistnode<_Key,_Val> *)new struct hashlistnode<_Key,_Val>;
+ newptr->key=key;
+ newptr->next=ptr;
+ table[(((_KeyInt)key)&mask)>>_Shift]=newptr;
+ return &newptr->val;
+ }
+
/** Lookup the corresponding value for the given key. */
_Val get(_Key key) {
struct hashlistnode<_Key,_Val> *search = table[(((_KeyInt)key) & mask)>>_Shift];
return (_Val)0;
}
+ /** Lookup the corresponding value for the given key. */
+ _Val * getptr(_Key key) {
+ struct hashlistnode<_Key,_Val> *search = table[(((_KeyInt)key) & mask)>>_Shift];
+
+ while(search!=NULL) {
+ if (search->key==key) {
+ return & search->val;
+ }
+ search=search->next;
+ }
+ return (_Val *) NULL;
+ }
+
/** Check whether the table contains a value for the given key. */
bool contains(_Key key) {
struct hashlistnode<_Key,_Val> *search = table[(((_KeyInt)key) & mask)>>_Shift];
diverge(NULL),
nextThread(THREAD_ID_T_NONE),
action_trace(new action_list_t()),
- thread_map(new std::map<int, Thread *>),
- obj_map(new std::map<const void *, action_list_t>()),
- obj_thrd_map(new std::map<void *, std::vector<action_list_t> >()),
+ thread_map(new HashTable<int, Thread *, int>()),
+ obj_map(new HashTable<const void *, action_list_t, uintptr_t, 4>()),
+ obj_thrd_map(new HashTable<void *, std::vector<action_list_t>, uintptr_t, 4 >()),
thrd_last_action(new std::vector<ModelAction *>(1)),
node_stack(new NodeStack()),
next_backtrack(NULL),
/** @brief Destructor */
ModelChecker::~ModelChecker()
{
- std::map<int, Thread *>::iterator it;
+ /* std::map<int, Thread *>::iterator it;
for (it = thread_map->begin(); it != thread_map->end(); it++)
- delete (*it).second;
+ delete (*it).second;*/
delete thread_map;
delete obj_thrd_map;
Thread *t;
if (nextThread == THREAD_ID_T_NONE)
return NULL;
- t = (*thread_map)[id_to_int(nextThread)];
+ t = thread_map->get(id_to_int(nextThread));
ASSERT(t != NULL);
return NULL;
}
/* linear search: from most recent to oldest */
- action_list_t *list = &(*obj_map)[act->get_location()];
+ action_list_t *list = obj_map->ensureptr(act->get_location());
action_list_t::reverse_iterator rit;
for (rit = list->rbegin(); rit != list->rend(); rit++) {
ModelAction *prev = *rit;
* @param rf The action that curr reads from. Must be a write.
*/
void ModelChecker::r_modification_order(ModelAction * curr, const ModelAction *rf) {
- std::vector<action_list_t> *thrd_lists = &(*obj_thrd_map)[curr->get_location()];
+ std::vector<action_list_t> *thrd_lists = obj_thrd_map->ensureptr(curr->get_location());
unsigned int i;
ASSERT(curr->is_read());
* @param curr The current action. Must be a write.
*/
void ModelChecker::w_modification_order(ModelAction * curr) {
- std::vector<action_list_t> *thrd_lists = &(*obj_thrd_map)[curr->get_location()];
+ std::vector<action_list_t> *thrd_lists = obj_thrd_map->ensureptr(curr->get_location());
unsigned int i;
ASSERT(curr->is_write());
int tid = id_to_int(act->get_tid());
action_trace->push_back(act);
- (*obj_map)[act->get_location()].push_back(act);
+ obj_map->ensureptr(act->get_location())->push_back(act);
- std::vector<action_list_t> *vec = &(*obj_thrd_map)[act->get_location()];
+ std::vector<action_list_t> *vec = obj_thrd_map->ensureptr(act->get_location());
if (tid >= (int)vec->size())
vec->resize(next_thread_id);
(*vec)[tid].push_back(act);
*/
ModelAction * ModelChecker::get_last_seq_cst(const void *location)
{
- action_list_t *list = &(*obj_map)[location];
+ action_list_t *list = obj_map->ensureptr(location);
/* Find: max({i in dom(S) | seq_cst(t_i) && isWrite(t_i) && samevar(t_i, t)}) */
action_list_t::reverse_iterator rit;
for (rit = list->rbegin(); rit != list->rend(); rit++)
*/
void ModelChecker::build_reads_from_past(ModelAction *curr)
{
- std::vector<action_list_t> *thrd_lists = &(*obj_thrd_map)[curr->get_location()];
+ std::vector<action_list_t> *thrd_lists = obj_thrd_map->ensureptr(curr->get_location());
unsigned int i;
ASSERT(curr->is_read());
int ModelChecker::add_thread(Thread *t)
{
- (*thread_map)[id_to_int(t->get_id())] = t;
+ thread_map->put(id_to_int(t->get_id()), t);
scheduler->add_thread(t);
return 0;
}
#define __MODEL_H__
#include <list>
-#include <map>
#include <vector>
#include <cstddef>
#include <ucontext.h>
int add_thread(Thread *t);
void remove_thread(Thread *t);
- Thread * get_thread(thread_id_t tid) { return (*thread_map)[id_to_int(tid)]; }
+ Thread * get_thread(thread_id_t tid) { return thread_map->get(id_to_int(tid)); }
thread_id_t get_next_id();
int get_num_threads();
ucontext_t *system_context;
action_list_t *action_trace;
- std::map<int, Thread *> *thread_map;
+ HashTable<int, Thread *, int> *thread_map;
/** Per-object list of actions. Maps an object (i.e., memory location)
* to a trace of all actions performed on the object. */
- std::map<const void *, action_list_t> *obj_map;
+ HashTable<const void *, action_list_t, uintptr_t, 4> *obj_map;
- std::map<void *, std::vector<action_list_t> > *obj_thrd_map;
+ HashTable<void *, std::vector<action_list_t>, uintptr_t, 4 > *obj_thrd_map;
std::vector<ModelAction *> *thrd_last_action;
NodeStack *node_stack;
ModelAction *next_backtrack;
static mspace sStaticSpace = NULL;
#endif
+/** Non-snapshotting calloc for our use. */
+void *MYCALLOC(size_t count, size_t size) {
+#if USE_MPROTECT_SNAPSHOT
+ static void *(*callocp)(size_t count, size_t size)=NULL;
+ char *error;
+ void *ptr;
+
+ /* get address of libc malloc */
+ if (!callocp) {
+ callocp = ( void * ( * )( size_t, size_t ) )dlsym(RTLD_NEXT, "calloc");
+ if ((error = dlerror()) != NULL) {
+ fputs(error, stderr);
+ exit(EXIT_FAILURE);
+ }
+ }
+ ptr = callocp(count, size);
+ return ptr;
+#else
+ if( !snapshotrecord) {
+ createSharedMemory();
+ }
+ if( NULL == sStaticSpace )
+ sStaticSpace = create_mspace_with_base( ( void * )( snapshotrecord->mSharedMemoryBase ), SHARED_MEMORY_DEFAULT -sizeof( struct SnapShot ), 1 );
+ return mspace_calloc( sStaticSpace, count, size );
+#endif
+}
+
/** Non-snapshotting malloc for our use. */
void *MYMALLOC(size_t size) {
#if USE_MPROTECT_SNAPSHOT
- static void *(*mallocp)(size_t size);
+ static void *(*mallocp)(size_t size)=NULL;
char *error;
void *ptr;
#define SNAPSHOTALLOC
void *MYMALLOC(size_t size);
+void *MYCALLOC(size_t count, size_t size);
void MYFREE(void *ptr);
void system_free( void * ptr );
#include <unistd.h>
#include <signal.h>
#include <stdlib.h>
-#include <map>
+#include "hashtable.h"
#include <cstring>
#include <cstdio>
#include "snapshot.h"
*/
void rollBack( snapshot_id theID ){
#if USE_MPROTECT_SNAPSHOT
- std::map< void *, bool, std::less< void * >, MyAlloc< std::pair< const void *, bool > > > duplicateMap;
+ HashTable< void *, bool, uintptr_t, 4, MYMALLOC, MYCALLOC, MYFREE> duplicateMap;
for(unsigned int region=0; region<snapshotrecord->lastRegion;region++) {
if( mprotect(snapshotrecord->regionsToSnapShot[region].basePtr, snapshotrecord->regionsToSnapShot[region].sizeInPages*sizeof(struct SnapShotPage), PROT_READ | PROT_WRITE ) == -1 ){
perror("mprotect");
}
}
for(unsigned int page=snapshotrecord->snapShots[theID].firstBackingPage; page<snapshotrecord->lastBackingPage; page++) {
- bool oldVal = false;
- if( duplicateMap.find( snapshotrecord->backingRecords[page].basePtrOfPage ) != duplicateMap.end() ){
- oldVal = true;
- }
- else{
- duplicateMap[ snapshotrecord->backingRecords[page].basePtrOfPage ] = true;
- }
- if( !oldVal ){
+ if( !duplicateMap.contains(snapshotrecord->backingRecords[page].basePtrOfPage )) {
+ duplicateMap.put(snapshotrecord->backingRecords[page].basePtrOfPage, true);
memcpy(snapshotrecord->backingRecords[page].basePtrOfPage, &snapshotrecord->backingStore[page], sizeof(struct SnapShotPage));
}
}