Compare exchange strong / weak may be called without the failure memory order
[c11llvm.git] / CDSPass.cpp
index 3cbc4bd530509c41e1a3b721b823c8c036eafb87..9a9e95c1e9ef7f6992ad47caf98674327ea686d4 100644 (file)
@@ -42,6 +42,7 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/EscapeEnumerator.h"
 #include <vector>
 
 using namespace llvm;
@@ -59,7 +60,7 @@ Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
        }
 
        if (print) {
-               errs() << position_string;
+               errs() << position_string << "\n";
        }
 
        return IRB.CreateGlobalStringPtr (position_string);
@@ -89,7 +90,7 @@ Type * VoidTy;
 
 static const size_t kNumberOfAccessSizes = 4;
 
-int getAtomicOrderIndex(AtomicOrdering order){
+int getAtomicOrderIndex(AtomicOrdering order) {
        switch (order) {
                case AtomicOrdering::Monotonic: 
                        return (int)AtomicOrderingCABI::relaxed;
@@ -109,15 +110,51 @@ int getAtomicOrderIndex(AtomicOrdering order){
        }
 }
 
+AtomicOrderingCABI indexToAtomicOrder(int index) { // Convert an integer memory-order index (0..5) into the matching AtomicOrderingCABI enumerator.
+       switch (index) {
+               case 0:
+                       return AtomicOrderingCABI::relaxed;
+               case 1:
+                       return AtomicOrderingCABI::consume;
+               case 2:
+                       return AtomicOrderingCABI::acquire;
+               case 3:
+                       return AtomicOrderingCABI::release;
+               case 4:
+                       return AtomicOrderingCABI::acq_rel;
+               case 5:
+                       return AtomicOrderingCABI::seq_cst;
+               default: // Any index outside 0..5: print a diagnostic on errs() and fall back to seq_cst.
+                       errs() << "Bad Atomic index\n";
+                       return AtomicOrderingCABI::seq_cst;
+       }
+}
+
+/* Given the index of a compare-exchange success order, return the index of the failure order to use; the mapping is written according to atomic_base.h: __cmpexch_failure_order */
+int AtomicCasFailureOrderIndex(int index) {
+       AtomicOrderingCABI succ_order = indexToAtomicOrder(index); // invalid indices come back as seq_cst
+       AtomicOrderingCABI fail_order;
+       if (succ_order == AtomicOrderingCABI::acq_rel) // acq_rel success -> acquire failure
+               fail_order = AtomicOrderingCABI::acquire;
+       else if (succ_order == AtomicOrderingCABI::release)  // release success -> relaxed failure
+               fail_order = AtomicOrderingCABI::relaxed;
+       else // every other order is reused unchanged as the failure order
+               fail_order = succ_order;
+
+       return (int) fail_order; // callers receive the order as a plain int index
+}
+
 namespace {
        struct CDSPass : public FunctionPass {
                static char ID;
                CDSPass() : FunctionPass(ID) {}
                bool runOnFunction(Function &F) override; 
+               StringRef getPassName() const override;
 
        private:
                void initializeCallbacks(Module &M);
                bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
+               bool instrumentVolatile(Instruction *I, const DataLayout &DL);
                bool isAtomicCall(Instruction *I);
                bool instrumentAtomic(Instruction *I, const DataLayout &DL);
                bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
@@ -133,6 +170,8 @@ namespace {
 
                Constant * CDSLoad[kNumberOfAccessSizes];
                Constant * CDSStore[kNumberOfAccessSizes];
+               Constant * CDSVolatileLoad[kNumberOfAccessSizes];
+               Constant * CDSVolatileStore[kNumberOfAccessSizes];
                Constant * CDSAtomicInit[kNumberOfAccessSizes];
                Constant * CDSAtomicLoad[kNumberOfAccessSizes];
                Constant * CDSAtomicStore[kNumberOfAccessSizes];
@@ -146,6 +185,10 @@ namespace {
        };
 }
 
+StringRef CDSPass::getPassName() const { // Definition of the getPassName() override declared in struct CDSPass.
+       return "CDSPass"; // fixed name string for this pass
+}
+
 static bool isVtableAccess(Instruction *I) {
        if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
                return Tag->isTBAAVtableAccess();
@@ -167,6 +210,8 @@ void CDSPass::initializeCallbacks(Module &M) {
 
        CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
                                                                VoidTy, Int8PtrTy);
+       CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
+                                                               VoidTy, Int8PtrTy);
 
        // Get the function to call from our untime library.
        for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
@@ -183,12 +228,18 @@ void CDSPass::initializeCallbacks(Module &M) {
                // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
                SmallString<32> LoadName("cds_load" + BitSizeStr);
                SmallString<32> StoreName("cds_store" + BitSizeStr);
+               SmallString<32> VolatileLoadName("cds_volatile_load" + BitSizeStr);
+               SmallString<32> VolatileStoreName("cds_volatile_store" + BitSizeStr);
                SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
                SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
                SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
 
                CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
                CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
+               CDSVolatileLoad[i]  = M.getOrInsertFunction(VolatileLoadName,
+                                                                       Ty, PtrTy, Int8PtrTy);
+               CDSVolatileStore[i] = M.getOrInsertFunction(VolatileStoreName, 
+                                                                       VoidTy, PtrTy, Ty, Int8PtrTy);
                CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
                                                                VoidTy, PtrTy, Ty, Int8PtrTy);
                CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
@@ -308,12 +359,14 @@ bool CDSPass::runOnFunction(Function &F) {
 
                SmallVector<Instruction*, 8> AllLoadsAndStores;
                SmallVector<Instruction*, 8> LocalLoadsAndStores;
+               SmallVector<Instruction*, 8> VolatileLoadsAndStores;
                SmallVector<Instruction*, 8> AtomicAccesses;
 
                std::vector<Instruction *> worklist;
 
                bool Res = false;
                bool HasAtomic = false;
+               bool HasVolatile = false;
                const DataLayout &DL = F.getParent()->getDataLayout();
 
                // errs() << "--- " << F.getName() << "---\n";
@@ -324,7 +377,15 @@ bool CDSPass::runOnFunction(Function &F) {
                                        AtomicAccesses.push_back(&I);
                                        HasAtomic = true;
                                } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
-                                       LocalLoadsAndStores.push_back(&I);
+                                       LoadInst *LI = dyn_cast<LoadInst>(&I);
+                                       StoreInst *SI = dyn_cast<StoreInst>(&I);
+                                       bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
+
+                                       if (isVolatile) {
+                                               VolatileLoadsAndStores.push_back(&I);
+                                               HasVolatile = true;
+                                       } else
+                                               LocalLoadsAndStores.push_back(&I);
                                } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
                                        // not implemented yet
                                }
@@ -334,8 +395,11 @@ bool CDSPass::runOnFunction(Function &F) {
                }
 
                for (auto Inst : AllLoadsAndStores) {
-                       // Res |= instrumentLoadOrStore(Inst, DL);
-                       // errs() << "load and store are replaced\n";
+                       Res |= instrumentLoadOrStore(Inst, DL);
+               }
+
+               for (auto Inst : VolatileLoadsAndStores) {
+                       Res |= instrumentVolatile(Inst, DL);
                }
 
                for (auto Inst : AtomicAccesses) {
@@ -343,24 +407,22 @@ bool CDSPass::runOnFunction(Function &F) {
                }
 
                // only instrument functions that contain atomics
-               if (Res && HasAtomic) {
-                       /*
+               if (Res && ( HasAtomic || HasVolatile) ) {
                        IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
+                       /* Unused for now
                        Value *ReturnAddress = IRB.CreateCall(
                                Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
                                IRB.getInt32(0));
+                       */
 
                        Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
-                       */
-                       //errs() << "function name: " << F.getName() << "\n";
-                       //IRB.CreateCall(CDSFuncEntry, FuncName);
+                       IRB.CreateCall(CDSFuncEntry, FuncName);
 
-/*
-                       EscapeEnumerator EE(F, "tsan_cleanup", ClHandleCxxExceptions);
+                       EscapeEnumerator EE(F, "cds_cleanup", true);
                        while (IRBuilder<> *AtExit = EE.Next()) {
-                         AtExit->CreateCall(TsanFuncExit, {});
+                         AtExit->CreateCall(CDSFuncExit, FuncName);
                        }
-*/
+
                        Res = true;
                }
        }
@@ -426,6 +488,8 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
        return false;
 
        int Idx = getMemoryAccessFuncIndex(Addr, DL);
+       if (Idx < 0)
+               return false;
 
 //  not supported by CDS yet
 /*  if (IsWrite && isVtableAccess(I)) {
@@ -462,10 +526,8 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
 
        if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
                        ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
-               //errs() << "A load or store of type ";
-               //errs() << *ArgType;
-               //errs() << " is passed in\n";
-               return false;   // if other types of load or stores are passed in
+               // if other types of load or stores are passed in
+               return false;   
        }
        IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
        if (IsWrite) NumInstrumentedWrites++;
@@ -473,6 +535,38 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
        return true;
 }
 
+bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) { // Replace a volatile load/store with a call to the size-specific cds_volatile_load*/cds_volatile_store* function; returns true only if I was replaced.
+       IRBuilder<> IRB(I);
+       Value *position = getPosition(I, IRB); // global string pointer built by getPosition for I; NOTE(review): created even when one of the early `return false` paths below is taken
+
+       if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+               assert( LI->isVolatile() );
+               Value *Addr = LI->getPointerOperand();
+               int Idx=getMemoryAccessFuncIndex(Addr, DL); // index into the per-access-size callback arrays; negative means unsupported
+               if (Idx < 0)
+                       return false; // leave the load untouched
+
+               Value *args[] = {Addr, position}; // argument order follows the (PtrTy, Int8PtrTy) declaration of CDSVolatileLoad
+               Instruction* funcInst=CallInst::Create(CDSVolatileLoad[Idx], args);
+               ReplaceInstWithInst(LI, funcInst); // the call takes the place of the original load
+       } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+               assert( SI->isVolatile() );
+               Value *Addr = SI->getPointerOperand();
+               int Idx=getMemoryAccessFuncIndex(Addr, DL); // same size lookup as for loads
+               if (Idx < 0)
+                       return false; // leave the store untouched
+
+               Value *val = SI->getValueOperand();
+               Value *args[] = {Addr, val, position}; // argument order follows the (PtrTy, Ty, Int8PtrTy) declaration of CDSVolatileStore
+               Instruction* funcInst=CallInst::Create(CDSVolatileStore[Idx], args);
+               ReplaceInstWithInst(SI, funcInst); // the call takes the place of the original store
+       } else {
+               return false; // neither a load nor a store: nothing to instrument
+       }
+
+       return true;
+}
+
 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        IRBuilder<> IRB(I);
 
@@ -485,6 +579,9 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
                Value *Addr = LI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
                Value *args[] = {Addr, order, position};
@@ -493,6 +590,9 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
                Value *Addr = SI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
                Value *val = SI->getValueOperand();
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
@@ -502,6 +602,9 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
                Value *Addr = RMWI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
                Value *val = RMWI->getValOperand();
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
@@ -513,6 +616,8 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
 
                Value *Addr = CASI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
 
                const unsigned ByteSize = 1U << Idx;
                const unsigned BitSize = ByteSize * 8;
@@ -613,8 +718,15 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_init; args = {obj, order}
        if (funName.contains("atomic_init")) {
+               Value *OrigVal = parameters[1];
+
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *args[] = {ptr, val, position};
 
                Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
@@ -680,13 +792,18 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                return true;
        } else if (funName.contains("atomic") && 
-                                       funName.contains("EEEE5store") ) {
+                                       funName.contains("store") ) {
                // does this version of call always have an atomic order as an argument?
                Value *OrigVal = parameters[1];
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
-               Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
+               Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
                Value *args[] = {ptr, val, order, position};
 
                Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
@@ -697,7 +814,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_fetch_*; args = {obj, val, order}
        if (funName.contains("atomic_fetch_") || 
-                       funName.contains("atomic_exchange") ) {
+               funName.contains("atomic_exchange")) {
+
+               /* TODO: implement stricter function name checking */
+               if (funName.contains("non"))
+                       return false;
+
                bool isExplicit = funName.contains("_explicit");
                Value *OrigVal = parameters[1];
 
@@ -720,7 +842,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                }
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *order;
                if (isExplicit)
                        order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
@@ -734,14 +861,36 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                return true;
        } else if (funName.contains("fetch")) {
-               errs() << "atomic exchange captured. Not implemented yet. ";
+               errs() << "atomic fetch captured. Not implemented yet. ";
                errs() << "See source file :";
                getPosition(CI, IRB, true);
+               return false;
        } else if (funName.contains("exchange") &&
                        !funName.contains("compare_exchange") ) {
-               errs() << "atomic exchange captured. Not implemented yet. ";
-               errs() << "See source file :";
-               getPosition(CI, IRB, true);
+               if (CI->getType()->isPointerTy()) {
+                       // Can not deal with this now
+                       errs() << "atomic exchange captured. Not implemented yet. ";
+                       errs() << "See source file :";
+                       getPosition(CI, IRB, true);
+
+                       return false;
+               }
+
+               Value *OrigVal = parameters[1];
+
+               Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
+               Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
+               Value *args[] = {ptr, val, order, position};
+               int op = AtomicRMWInst::Xchg;
+               
+               Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
+               ReplaceInstWithInst(CI, funcInst);
        }
 
        /* atomic_compare_exchange_*; 
@@ -757,7 +906,18 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                Value *order_succ, *order_fail;
                if (isExplicit) {
                        order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
-                       order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+
+                       if (parameters.size() > 4) {
+                               order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+                       } else {
+                               /* The failure order is not provided */
+                               order_fail = order_succ;
+                               ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
+                               int index = order_succ_cast->getSExtValue();
+
+                               order_fail = ConstantInt::get(OrdTy,
+                                                               AtomicCasFailureOrderIndex(index));
+                       }
                } else  {
                        order_succ = ConstantInt::get(OrdTy, 
                                                        (int) AtomicOrderingCABI::seq_cst);
@@ -780,7 +940,18 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                Value *order_succ, *order_fail;
                order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
-               order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+
+               if (parameters.size() > 4) {
+                       order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+               } else {
+                       /* The failure order is not provided */
+                       order_fail = order_succ;
+                       ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
+                       int index = order_succ_cast->getSExtValue();
+
+                       order_fail = ConstantInt::get(OrdTy,
+                                                       AtomicCasFailureOrderIndex(index));
+               }
 
                Value *args[] = {Addr, CmpOperand, NewOperand, 
                                                        order_succ, order_fail, position};
@@ -806,7 +977,10 @@ int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
                return -1;
        }
        size_t Idx = countTrailingZeros(TypeSize / 8);
-       assert(Idx < kNumberOfAccessSizes);
+       //assert(Idx < kNumberOfAccessSizes);
+       if (Idx >= kNumberOfAccessSizes) {
+               return -1;
+       }
        return Idx;
 }
 
@@ -818,6 +992,13 @@ static void registerCDSPass(const PassManagerBuilder &,
                                                        legacy::PassManagerBase &PM) {
        PM.add(new CDSPass());
 }
+
+/* Enable the pass when opt level is greater than 0: register registerCDSPass at the EP_OptimizerLast extension point */
+static RegisterStandardPasses 
+       RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
+registerCDSPass);
+
+/* Enable the pass when opt level is 0 */
 static RegisterStandardPasses 
-       RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
+       RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0, // second registration of the same callback, at the EP_EnabledOnOptLevel0 extension point
 registerCDSPass);