From: weiyu <weiyuluo1232@gmail.com>
Date: Fri, 24 Jul 2020 19:18:16 +0000 (-0700)
Subject: Fix missed normal loads/stores
X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=HEAD;p=c11llvm.git

Fix missed normal loads/stores
---

diff --git a/CDSPass.cpp b/CDSPass.cpp
index 076ddba..a33738a 100644
--- a/CDSPass.cpp
+++ b/CDSPass.cpp
@@ -273,9 +273,9 @@ void CDSPass::initializeCallbacks(Module &M) {
 		SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
 
 		CDSLoad[i]  = checkCDSPassInterfaceFunction(
-							M.getOrInsertFunction(LoadName, Attr, VoidTy, PtrTy));
+							M.getOrInsertFunction(LoadName, Attr, VoidTy, Int8PtrTy));
 		CDSStore[i] = checkCDSPassInterfaceFunction(
-							M.getOrInsertFunction(StoreName, Attr, VoidTy, PtrTy));
+							M.getOrInsertFunction(StoreName, Attr, VoidTy, Int8PtrTy));
 		CDSVolatileLoad[i]  = checkCDSPassInterfaceFunction(
 								M.getOrInsertFunction(VolatileLoadName,
 								Attr, Ty, PtrTy, Int8PtrTy));
@@ -652,19 +652,9 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
 	}
 
 	// TODO: unaligned reads and writes
-
 	Value *OnAccessFunc = nullptr;
 	OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
-
-	Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
-
-	if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
-			ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
-		// if other types of load or stores are passed in
-		return false;	
-	}
-
-	IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
+	IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
 	if (IsWrite) NumInstrumentedWrites++;
 	else         NumInstrumentedReads++;
 	return true;
@@ -672,10 +662,6 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
 
 bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
 	IRBuilder<> IRB(I);
-	const unsigned ByteSize = 1U << Idx;
-	const unsigned BitSize = ByteSize * 8;
-	Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
-	Type *PtrTy = Ty->getPointerTo();
 	Value *position = getPosition(I, IRB);
 
 	if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
@@ -683,19 +669,25 @@ bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
 		int Idx=getMemoryAccessFuncIndex(Addr, DL);
 		if (Idx < 0)
 			return false;
-
+		const unsigned ByteSize = 1U << Idx;
+		const unsigned BitSize = ByteSize * 8;
+		Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
+		Type *PtrTy = Ty->getPointerTo();
 		Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), position};
+
 		Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType();
 		Value *C = IRB.CreateCall(CDSVolatileLoad[Idx], Args);
 		Value *Cast = IRB.CreateBitOrPointerCast(C, OrigTy);
 		I->replaceAllUsesWith(Cast);
 	} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
-		assert( SI->isVolatile() );
 		Value *Addr = SI->getPointerOperand();
 		int Idx=getMemoryAccessFuncIndex(Addr, DL);
 		if (Idx < 0)
 			return false;
-
+		const unsigned ByteSize = 1U << Idx;
+		const unsigned BitSize = ByteSize * 8;
+		Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
+		Type *PtrTy = Ty->getPointerTo();
 		Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
 					  IRB.CreateBitOrPointerCast(SI->getValueOperand(), Ty),
 					  position};