Compare exchange strong / weak may be called without the failure memory order
[c11llvm.git] / CDSPass.cpp
index dede6bc6b0652e27aa59524f8b15644568489a1f..9a9e95c1e9ef7f6992ad47caf98674327ea686d4 100644 (file)
@@ -89,9 +89,8 @@ Type * Int64PtrTy;
 Type * VoidTy;
 
 static const size_t kNumberOfAccessSizes = 4;
-static const int volatile_order = 6;
 
-int getAtomicOrderIndex(AtomicOrdering order){
+int getAtomicOrderIndex(AtomicOrdering order) {
        switch (order) {
                case AtomicOrdering::Monotonic: 
                        return (int)AtomicOrderingCABI::relaxed;
@@ -111,11 +110,46 @@ int getAtomicOrderIndex(AtomicOrdering order){
        }
 }
 
+AtomicOrderingCABI indexToAtomicOrder(int index) {
+       switch (index) {
+               case 0:
+                       return AtomicOrderingCABI::relaxed;
+               case 1:
+                       return AtomicOrderingCABI::consume;
+               case 2:
+                       return AtomicOrderingCABI::acquire;
+               case 3:
+                       return AtomicOrderingCABI::release;
+               case 4:
+                       return AtomicOrderingCABI::acq_rel;
+               case 5:
+                       return AtomicOrderingCABI::seq_cst;
+               default:
+                       errs() << "Bad Atomic index\n";
+                       return AtomicOrderingCABI::seq_cst;
+       }
+}
+
+/* According to atomic_base.h: __cmpexch_failure_order */
+int AtomicCasFailureOrderIndex(int index) {
+       AtomicOrderingCABI succ_order = indexToAtomicOrder(index);
+       AtomicOrderingCABI fail_order;
+       if (succ_order == AtomicOrderingCABI::acq_rel)
+               fail_order = AtomicOrderingCABI::acquire;
+       else if (succ_order == AtomicOrderingCABI::release) 
+               fail_order = AtomicOrderingCABI::relaxed;
+       else
+               fail_order = succ_order;
+
+       return (int) fail_order;
+}
+
 namespace {
        struct CDSPass : public FunctionPass {
                static char ID;
                CDSPass() : FunctionPass(ID) {}
                bool runOnFunction(Function &F) override; 
+               StringRef getPassName() const override;
 
        private:
                void initializeCallbacks(Module &M);
@@ -151,6 +185,10 @@ namespace {
        };
 }
 
+StringRef CDSPass::getPassName() const {
+       return "CDSPass";
+}
+
 static bool isVtableAccess(Instruction *I) {
        if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
                return Tag->isTBAAVtableAccess();
@@ -199,9 +237,9 @@ void CDSPass::initializeCallbacks(Module &M) {
                CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
                CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
                CDSVolatileLoad[i]  = M.getOrInsertFunction(VolatileLoadName,
-                                                                       Ty, PtrTy, OrdTy, Int8PtrTy);
+                                                                       Ty, PtrTy, Int8PtrTy);
                CDSVolatileStore[i] = M.getOrInsertFunction(VolatileStoreName, 
-                                                                       VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
+                                                                       VoidTy, PtrTy, Ty, Int8PtrTy);
                CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
                                                                VoidTy, PtrTy, Ty, Int8PtrTy);
                CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
@@ -328,6 +366,7 @@ bool CDSPass::runOnFunction(Function &F) {
 
                bool Res = false;
                bool HasAtomic = false;
+               bool HasVolatile = false;
                const DataLayout &DL = F.getParent()->getDataLayout();
 
                // errs() << "--- " << F.getName() << "---\n";
@@ -342,9 +381,10 @@ bool CDSPass::runOnFunction(Function &F) {
                                        StoreInst *SI = dyn_cast<StoreInst>(&I);
                                        bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
 
-                                       if (isVolatile)
+                                       if (isVolatile) {
                                                VolatileLoadsAndStores.push_back(&I);
-                                       else
+                                               HasVolatile = true;
+                                       } else
                                                LocalLoadsAndStores.push_back(&I);
                                } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
                                        // not implemented yet
@@ -367,7 +407,7 @@ bool CDSPass::runOnFunction(Function &F) {
                }
 
                // only instrument functions that contain atomics
-               if (Res && HasAtomic) {
+               if (Res && ( HasAtomic || HasVolatile) ) {
                        IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
                        /* Unused for now
                        Value *ReturnAddress = IRB.CreateCall(
@@ -385,8 +425,6 @@ bool CDSPass::runOnFunction(Function &F) {
 
                        Res = true;
                }
-
-               F.dump();
        }
 
        return false;
@@ -508,8 +546,7 @@ bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
                if (Idx < 0)
                        return false;
 
-               Value *order = ConstantInt::get(OrdTy, volatile_order);
-               Value *args[] = {Addr, order, position};
+               Value *args[] = {Addr, position};
                Instruction* funcInst=CallInst::Create(CDSVolatileLoad[Idx], args);
                ReplaceInstWithInst(LI, funcInst);
        } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
@@ -520,8 +557,7 @@ bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
                        return false;
 
                Value *val = SI->getValueOperand();
-               Value *order = ConstantInt::get(OrdTy, volatile_order);
-               Value *args[] = {Addr, val, order, position};
+               Value *args[] = {Addr, val, position};
                Instruction* funcInst=CallInst::Create(CDSVolatileStore[Idx], args);
                ReplaceInstWithInst(SI, funcInst);
        } else {
@@ -682,8 +718,15 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_init; args = {obj, order}
        if (funName.contains("atomic_init")) {
+               Value *OrigVal = parameters[1];
+
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *args[] = {ptr, val, position};
 
                Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
@@ -749,12 +792,17 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                return true;
        } else if (funName.contains("atomic") && 
-                                       funName.contains("EEEE5store") ) {
+                                       funName.contains("store") ) {
                // does this version of call always have an atomic order as an argument?
                Value *OrigVal = parameters[1];
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
                Value *args[] = {ptr, val, order, position};
 
@@ -766,7 +814,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_fetch_*; args = {obj, val, order}
        if (funName.contains("atomic_fetch_") || 
-                       funName.contains("atomic_exchange") ) {
+               funName.contains("atomic_exchange")) {
+
+               /* TODO: implement stricter function name checking */
+               if (funName.contains("non"))
+                       return false;
+
                bool isExplicit = funName.contains("_explicit");
                Value *OrigVal = parameters[1];
 
@@ -789,7 +842,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                }
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *order;
                if (isExplicit)
                        order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
@@ -803,14 +861,36 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                return true;
        } else if (funName.contains("fetch")) {
-               errs() << "atomic exchange captured. Not implemented yet. ";
+               errs() << "atomic fetch captured. Not implemented yet. ";
                errs() << "See source file :";
                getPosition(CI, IRB, true);
+               return false;
        } else if (funName.contains("exchange") &&
                        !funName.contains("compare_exchange") ) {
-               errs() << "atomic exchange captured. Not implemented yet. ";
-               errs() << "See source file :";
-               getPosition(CI, IRB, true);
+               if (CI->getType()->isPointerTy()) {
+                       // Can not deal with this now
+                       errs() << "atomic exchange captured. Not implemented yet. ";
+                       errs() << "See source file :";
+                       getPosition(CI, IRB, true);
+
+                       return false;
+               }
+
+               Value *OrigVal = parameters[1];
+
+               Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
+               Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
+               Value *args[] = {ptr, val, order, position};
+               int op = AtomicRMWInst::Xchg;
+               
+               Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
+               ReplaceInstWithInst(CI, funcInst);
        }
 
        /* atomic_compare_exchange_*; 
@@ -826,7 +906,18 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                Value *order_succ, *order_fail;
                if (isExplicit) {
                        order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
-                       order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+
+                       if (parameters.size() > 4) {
+                               order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+                       } else {
+                               /* The failure order is not provided */
+                               order_fail = order_succ;
+                               ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
+                               int index = order_succ_cast->getSExtValue();
+
+                               order_fail = ConstantInt::get(OrdTy,
+                                                               AtomicCasFailureOrderIndex(index));
+                       }
                } else  {
                        order_succ = ConstantInt::get(OrdTy, 
                                                        (int) AtomicOrderingCABI::seq_cst);
@@ -849,7 +940,18 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                Value *order_succ, *order_fail;
                order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
-               order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+
+               if (parameters.size() > 4) {
+                       order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+               } else {
+                       /* The failure order is not provided */
+                       order_fail = order_succ;
+                       ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
+                       int index = order_succ_cast->getSExtValue();
+
+                       order_fail = ConstantInt::get(OrdTy,
+                                                       AtomicCasFailureOrderIndex(index));
+               }
 
                Value *args[] = {Addr, CmpOperand, NewOperand, 
                                                        order_succ, order_fail, position};