Fix a bug related to atomic bool
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
12 //
13 // The tool is under development, for the details about previous versions see
14 // http://code.google.com/p/data-race-test
15 //
16 // The instrumentation phase is quite simple:
17 //   - Insert calls to run-time library before every memory access.
18 //      - Optimizations may apply to avoid instrumenting some of the accesses.
19 //   - Insert calls at function entry/exit.
20 // The rest is handled by the run-time library.
21 //===----------------------------------------------------------------------===//
22
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/Analysis/CaptureTracking.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/EscapeEnumerator.h"
46 #include <vector>
47
48 using namespace llvm;
49
50 #define DEBUG_TYPE "CDS"
51 #include <llvm/IR/DebugLoc.h>
52
53 Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
54 {
55         const DebugLoc & debug_location = I->getDebugLoc ();
56         std::string position_string;
57         {
58                 llvm::raw_string_ostream position_stream (position_string);
59                 debug_location . print (position_stream);
60         }
61
62         if (print) {
63                 errs() << position_string << "\n";
64         }
65
66         return IRB.CreateGlobalStringPtr (position_string);
67 }
68
// Pass-wide instrumentation counters, reported via the LLVM -stats facility.
STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
// STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
// STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");

STATISTIC(NumOmittedReadsBeforeWrite,
          "Number of reads ignored due to following writes");
STATISTIC(NumOmittedReadsFromConstantGlobals,
          "Number of reads from constant globals");
STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");

// Cached IR types, populated by CDSPass::initializeCallbacks().
// OrdTy (i32) carries the memory-order index computed by getAtomicOrderIndex.
Type * OrdTy;

Type * Int8PtrTy;
Type * Int16PtrTy;
Type * Int32PtrTy;
Type * Int64PtrTy;

Type * VoidTy;

// Supported access sizes are 1, 2, 4 and 8 bytes; arrays below are indexed
// by log2 of the byte size.
static const size_t kNumberOfAccessSizes = 4;
92
93 int getAtomicOrderIndex(AtomicOrdering order){
94         switch (order) {
95                 case AtomicOrdering::Monotonic: 
96                         return (int)AtomicOrderingCABI::relaxed;
97                 //  case AtomicOrdering::Consume:         // not specified yet
98                 //    return AtomicOrderingCABI::consume;
99                 case AtomicOrdering::Acquire: 
100                         return (int)AtomicOrderingCABI::acquire;
101                 case AtomicOrdering::Release: 
102                         return (int)AtomicOrderingCABI::release;
103                 case AtomicOrdering::AcquireRelease: 
104                         return (int)AtomicOrderingCABI::acq_rel;
105                 case AtomicOrdering::SequentiallyConsistent: 
106                         return (int)AtomicOrderingCABI::seq_cst;
107                 default:
108                         // unordered or Not Atomic
109                         return -1;
110         }
111 }
112
namespace {
	/// Function pass that rewrites atomic, volatile and selected plain
	/// memory accesses into calls to the CDS model-checker runtime.
	struct CDSPass : public FunctionPass {
		static char ID;
		CDSPass() : FunctionPass(ID) {}
		bool runOnFunction(Function &F) override; 
		StringRef getPassName() const override;

	private:
		// Fills in the callback Constants and cached pointer types below.
		void initializeCallbacks(Module &M);
		// Instrumentation helpers; each returns true iff it changed the IR.
		bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
		bool instrumentVolatile(Instruction *I, const DataLayout &DL);
		bool isAtomicCall(Instruction *I);
		bool instrumentAtomic(Instruction *I, const DataLayout &DL);
		bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
		// TSan-style filtering of candidate plain loads/stores: moves the
		// interesting subset of Local into All and clears Local.
		void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
											SmallVectorImpl<Instruction *> &All,
											const DataLayout &DL);
		bool addrPointsToConstantData(Value *Addr);
		// Maps an address to log2(access byte size), or -1 if unsupported.
		int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);

		// Callbacks to run-time library are computed in doInitialization.
		Constant * CDSFuncEntry;
		Constant * CDSFuncExit;

		// Per-size runtime entry points, indexed by log2(byte size).
		Constant * CDSLoad[kNumberOfAccessSizes];
		Constant * CDSStore[kNumberOfAccessSizes];
		Constant * CDSVolatileLoad[kNumberOfAccessSizes];
		Constant * CDSVolatileStore[kNumberOfAccessSizes];
		Constant * CDSAtomicInit[kNumberOfAccessSizes];
		Constant * CDSAtomicLoad[kNumberOfAccessSizes];
		Constant * CDSAtomicStore[kNumberOfAccessSizes];
		Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
		// CAS flavors: _v1 returns the old value, _v2 returns an i1 success
		// flag (see the signatures registered in initializeCallbacks).
		Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
		Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
		Constant * CDSAtomicThreadFence;

		// Substrings used by isAtomicCall() to recognize atomic library calls.
		std::vector<StringRef> AtomicFuncNames;
		std::vector<StringRef> PartialAtomicFuncNames;
	};
}
153
154 StringRef CDSPass::getPassName() const {
155         return "CDSPass";
156 }
157
158 static bool isVtableAccess(Instruction *I) {
159         if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
160                 return Tag->isTBAAVtableAccess();
161         return false;
162 }
163
/// Declare (or look up) every CDS runtime entry point this pass may call,
/// and cache the pointer types used to classify instrumented accesses.
/// Must run before any instrument* helper; called from runOnFunction.
void CDSPass::initializeCallbacks(Module &M) {
	LLVMContext &Ctx = M.getContext();

	Type * Int1Ty = Type::getInt1Ty(Ctx);
	OrdTy = Type::getInt32Ty(Ctx);

	Int8PtrTy  = Type::getInt8PtrTy(Ctx);
	Int16PtrTy = Type::getInt16PtrTy(Ctx);
	Int32PtrTy = Type::getInt32PtrTy(Ctx);
	Int64PtrTy = Type::getInt64PtrTy(Ctx);

	VoidTy = Type::getVoidTy(Ctx);

	// Function entry/exit hooks take the function name string.
	CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
								VoidTy, Int8PtrTy);
	CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
								VoidTy, Int8PtrTy);

	// Get the functions to call from our runtime library, one set per
	// access size (1/2/4/8 bytes).
	for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
		const unsigned ByteSize = 1U << i;
		const unsigned BitSize = ByteSize * 8;

		std::string ByteSizeStr = utostr(ByteSize);
		std::string BitSizeStr = utostr(BitSize);

		Type *Ty = Type::getIntNTy(Ctx, BitSize);
		Type *PtrTy = Ty->getPointerTo();

		// e.g. uint8_t cds_atomic_load8 (void * obj, int atomic_index)
		//      void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
		SmallString<32> LoadName("cds_load" + BitSizeStr);
		SmallString<32> StoreName("cds_store" + BitSizeStr);
		SmallString<32> VolatileLoadName("cds_volatile_load" + BitSizeStr);
		SmallString<32> VolatileStoreName("cds_volatile_store" + BitSizeStr);
		SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
		SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
		SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);

		CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
		CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
		// Volatile/atomic variants additionally take the source position
		// string (Int8PtrTy) produced by getPosition().
		CDSVolatileLoad[i]  = M.getOrInsertFunction(VolatileLoadName,
									Ty, PtrTy, Int8PtrTy);
		CDSVolatileStore[i] = M.getOrInsertFunction(VolatileStoreName, 
									VoidTy, PtrTy, Ty, Int8PtrTy);
		CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
								VoidTy, PtrTy, Ty, Int8PtrTy);
		CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
								Ty, PtrTy, OrdTy, Int8PtrTy);
		CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
								VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);

		// Read-modify-write entry points; ops without a runtime callback
		// stay nullptr and are simply never instrumented.
		for (int op = AtomicRMWInst::FIRST_BINOP; 
			op <= AtomicRMWInst::LAST_BINOP; ++op) {
			CDSAtomicRMW[op][i] = nullptr;
			std::string NamePart;

			if (op == AtomicRMWInst::Xchg)
				NamePart = "_exchange";
			else if (op == AtomicRMWInst::Add) 
				NamePart = "_fetch_add";
			else if (op == AtomicRMWInst::Sub)
				NamePart = "_fetch_sub";
			else if (op == AtomicRMWInst::And)
				NamePart = "_fetch_and";
			else if (op == AtomicRMWInst::Or)
				NamePart = "_fetch_or";
			else if (op == AtomicRMWInst::Xor)
				NamePart = "_fetch_xor";
			else
				continue;

			SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
			CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
										Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
		}

		// Only the strong CAS version is supported.
		// _v1 returns the old value; _v2 takes the expected value by
		// pointer and returns an i1 success flag.
		SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
		SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
		CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
								Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
		CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
								Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
	}

	CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
													VoidTy, OrdTy, Int8PtrTy);
}
253
/// Return false for addresses that must never be instrumented: PGO counter
/// sections, gcov bookkeeping globals, and non-default address spaces.
static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
	// Peel off GEPs and BitCasts.
	Addr = Addr->stripInBoundsOffsets();

	if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
		if (GV->hasSection()) {
			StringRef SectionName = GV->getSection();
			// Check if the global is in the PGO counters section.
			auto OF = Triple(M->getTargetTriple()).getObjectFormat();
			if (SectionName.endswith(
			      getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
				return false;
		}

		// Check if the global is private gcov data.
		if (GV->getName().startswith("__llvm_gcov") ||
		GV->getName().startswith("__llvm_gcda"))
		return false;
	}

	// Do not instrument accesses from different address spaces; we cannot deal
	// with them.
	if (Addr) {
		Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
		if (PtrTy->getPointerAddressSpace() != 0)
			return false;
	}

	return true;
}
284
/// Return true if \p Addr can only refer to immutable data (constant global
/// or vtable pointer), so a read through it cannot participate in a race.
/// Bumps the corresponding NumOmitted* statistic when it fires.
bool CDSPass::addrPointsToConstantData(Value *Addr) {
	// If this is a GEP, just analyze its pointer operand.
	if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
		Addr = GEP->getPointerOperand();

	if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
		if (GV->isConstant()) {
			// Reads from constant globals can not race with any writes.
			NumOmittedReadsFromConstantGlobals++;
			return true;
		}
	} else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
		if (isVtableAccess(L)) {
			// Reads from a vtable pointer can not race with any writes.
			NumOmittedReadsFromVtable++;
			return true;
		}
	}
	return false;
}
305
306 bool CDSPass::runOnFunction(Function &F) {
307         if (F.getName() == "main") {
308                 F.setName("user_main");
309                 errs() << "main replaced by user_main\n";
310         }
311
312         if (true) {
313                 initializeCallbacks( *F.getParent() );
314
315                 AtomicFuncNames = 
316                 {
317                         "atomic_init", "atomic_load", "atomic_store", 
318                         "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
319                 };
320
321                 PartialAtomicFuncNames = 
322                 { 
323                         "load", "store", "fetch", "exchange", "compare_exchange_" 
324                 };
325
326                 SmallVector<Instruction*, 8> AllLoadsAndStores;
327                 SmallVector<Instruction*, 8> LocalLoadsAndStores;
328                 SmallVector<Instruction*, 8> VolatileLoadsAndStores;
329                 SmallVector<Instruction*, 8> AtomicAccesses;
330
331                 std::vector<Instruction *> worklist;
332
333                 bool Res = false;
334                 bool HasAtomic = false;
335                 bool HasVolatile = false;
336                 const DataLayout &DL = F.getParent()->getDataLayout();
337
338                 // errs() << "--- " << F.getName() << "---\n";
339
340                 for (auto &B : F) {
341                         for (auto &I : B) {
342                                 if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
343                                         AtomicAccesses.push_back(&I);
344                                         HasAtomic = true;
345                                 } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
346                                         LoadInst *LI = dyn_cast<LoadInst>(&I);
347                                         StoreInst *SI = dyn_cast<StoreInst>(&I);
348                                         bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
349
350                                         if (isVolatile) {
351                                                 VolatileLoadsAndStores.push_back(&I);
352                                                 HasVolatile = true;
353                                         } else
354                                                 LocalLoadsAndStores.push_back(&I);
355                                 } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
356                                         // not implemented yet
357                                 }
358                         }
359
360                         chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
361                 }
362
363                 for (auto Inst : AllLoadsAndStores) {
364                         Res |= instrumentLoadOrStore(Inst, DL);
365                 }
366
367                 for (auto Inst : VolatileLoadsAndStores) {
368                         Res |= instrumentVolatile(Inst, DL);
369                 }
370
371                 for (auto Inst : AtomicAccesses) {
372                         Res |= instrumentAtomic(Inst, DL);
373                 }
374
375                 // only instrument functions that contain atomics
376                 if (Res && ( HasAtomic || HasVolatile) ) {
377                         IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
378                         /* Unused for now
379                         Value *ReturnAddress = IRB.CreateCall(
380                                 Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
381                                 IRB.getInt32(0));
382                         */
383
384                         Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
385                         IRB.CreateCall(CDSFuncEntry, FuncName);
386
387                         EscapeEnumerator EE(F, "cds_cleanup", true);
388                         while (IRBuilder<> *AtExit = EE.Next()) {
389                           AtExit->CreateCall(CDSFuncExit, FuncName);
390                         }
391
392                         Res = true;
393                 }
394         }
395
396         return false;
397 }
398
/// Filter the plain loads/stores collected for one basic block (TSan-style):
/// drop reads that are followed by a write to the same address within the
/// block, reads of constant data, and accesses to non-captured allocas.
/// Survivors are appended to \p All; \p Local is cleared.
void CDSPass::chooseInstructionsToInstrument(
	SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
	const DataLayout &DL) {
	SmallPtrSet<Value*, 8> WriteTargets;
	// Iterate from the end so writes are seen before the reads they shadow.
	for (Instruction *I : reverse(Local)) {
		if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
			Value *Addr = Store->getPointerOperand();
			if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
				continue;
			WriteTargets.insert(Addr);
		} else {
			LoadInst *Load = cast<LoadInst>(I);
			Value *Addr = Load->getPointerOperand();
			if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
				continue;
			if (WriteTargets.count(Addr)) {
				// We will write to this temp, so no reason to analyze the read.
				NumOmittedReadsBeforeWrite++;
				continue;
			}
			if (addrPointsToConstantData(Addr)) {
				// Addr points to some constant data -- it can not race with any writes.
				continue;
			}
		}
		Value *Addr = isa<StoreInst>(*I)
			? cast<StoreInst>(I)->getPointerOperand()
			: cast<LoadInst>(I)->getPointerOperand();
		if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
				!PointerMayBeCaptured(Addr, true, true)) {
			// The variable is addressable but not captured, so it cannot be
			// referenced from a different thread and participate in a data race
			// (see llvm/Analysis/CaptureTracking.h for details).
			NumOmittedNonCaptured++;
			continue;
		}
		All.push_back(I);
	}
	Local.clear();
}
440
441
442 bool CDSPass::instrumentLoadOrStore(Instruction *I,
443                                                                         const DataLayout &DL) {
444         IRBuilder<> IRB(I);
445         bool IsWrite = isa<StoreInst>(*I);
446         Value *Addr = IsWrite
447                 ? cast<StoreInst>(I)->getPointerOperand()
448                 : cast<LoadInst>(I)->getPointerOperand();
449
450         // swifterror memory addresses are mem2reg promoted by instruction selection.
451         // As such they cannot have regular uses like an instrumentation function and
452         // it makes no sense to track them as memory.
453         if (Addr->isSwiftError())
454         return false;
455
456         int Idx = getMemoryAccessFuncIndex(Addr, DL);
457         if (Idx < 0)
458                 return false;
459
460 //  not supported by CDS yet
461 /*  if (IsWrite && isVtableAccess(I)) {
462     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
463     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
464     // StoredValue may be a vector type if we are storing several vptrs at once.
465     // In this case, just take the first element of the vector since this is
466     // enough to find vptr races.
467     if (isa<VectorType>(StoredValue->getType()))
468       StoredValue = IRB.CreateExtractElement(
469           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
470     if (StoredValue->getType()->isIntegerTy())
471       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
472     // Call TsanVptrUpdate.
473     IRB.CreateCall(TsanVptrUpdate,
474                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
475                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
476     NumInstrumentedVtableWrites++;
477     return true;
478   }
479
480   if (!IsWrite && isVtableAccess(I)) {
481     IRB.CreateCall(TsanVptrLoad,
482                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
483     NumInstrumentedVtableReads++;
484     return true;
485   }
486 */
487
488         Value *OnAccessFunc = nullptr;
489         OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
490
491         Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
492
493         if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
494                         ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
495                 // if other types of load or stores are passed in
496                 return false;   
497         }
498         IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
499         if (IsWrite) NumInstrumentedWrites++;
500         else         NumInstrumentedReads++;
501         return true;
502 }
503
/// Replace a volatile load/store with a call to the matching
/// cds_volatile_load<N>/cds_volatile_store<N> runtime function.
/// Returns true iff the instruction was replaced.
// NOTE(review): Addr is passed to the callback without a pointer cast —
// presumably getMemoryAccessFuncIndex only accepts addresses whose pointee
// is a plain integer of the matching width; verify against its definition.
bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
	IRBuilder<> IRB(I);
	// Source position string forwarded to the runtime for diagnostics.
	Value *position = getPosition(I, IRB);

	if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
		assert( LI->isVolatile() );
		Value *Addr = LI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		Value *args[] = {Addr, position};
		Instruction* funcInst=CallInst::Create(CDSVolatileLoad[Idx], args);
		// The runtime call's return value takes over all uses of the load.
		ReplaceInstWithInst(LI, funcInst);
	} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
		assert( SI->isVolatile() );
		Value *Addr = SI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		Value *val = SI->getValueOperand();
		Value *args[] = {Addr, val, position};
		Instruction* funcInst=CallInst::Create(CDSVolatileStore[Idx], args);
		ReplaceInstWithInst(SI, funcInst);
	} else {
		// Neither a load nor a store: nothing to do.
		return false;
	}

	return true;
}
535
/// Rewrite one atomic instruction (load, store, RMW, cmpxchg, fence) into
/// the corresponding CDS runtime call.  Atomic *library* calls are routed
/// to instrumentAtomicCall instead.  Returns true iff the IR was changed
/// (note: also returns true for atomic instruction kinds that fall through
/// every branch unhandled).
bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
	IRBuilder<> IRB(I);

	if (auto *CI = dyn_cast<CallInst>(I)) {
		return instrumentAtomicCall(CI, DL);
	}

	// Source position string forwarded to the runtime for diagnostics.
	Value *position = getPosition(I, IRB);

	if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
		Value *Addr = LI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, order, position};
		Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
		ReplaceInstWithInst(LI, funcInst);
	} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
		Value *Addr = SI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
		Value *val = SI->getValueOperand();
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, val, order, position};
		Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
		ReplaceInstWithInst(SI, funcInst);
	} else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
		Value *Addr = RMWI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		// NOTE(review): CDSAtomicRMW may be nullptr for ops with no runtime
		// callback (see initializeCallbacks) — presumably such ops never
		// reach this point; confirm against supported RMW operations.
		int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
		Value *val = RMWI->getValOperand();
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, val, order, position};
		Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
		ReplaceInstWithInst(RMWI, funcInst);
	} else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
		IRBuilder<> IRB(CASI);

		Value *Addr = CASI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		const unsigned ByteSize = 1U << Idx;
		const unsigned BitSize = ByteSize * 8;
		Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
		Type *PtrTy = Ty->getPointerTo();

		// Normalize operands to the iN the runtime callback expects.
		Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
		Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);

		int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
		int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
		Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
		Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);

		Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
						 CmpOperand, NewOperand,
						 order_succ, order_fail, position};

		// _v1 returns the old value; success is recomputed by comparing it
		// with the expected value, matching cmpxchg semantics.
		CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[Idx], Args);
		Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);

		Value *OldVal = funcInst;
		Type *OrigOldValTy = CASI->getNewValOperand()->getType();
		if (Ty != OrigOldValTy) {
			// The value is a pointer, so we need to cast the return value.
			OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
		}

		// Rebuild the { old value, success flag } struct cmpxchg produces.
		Value *Res =
		  IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
		Res = IRB.CreateInsertValue(Res, Success, 1);

		I->replaceAllUsesWith(Res);
		I->eraseFromParent();
	} else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
		int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *Args[] = {order, position};

		CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
		ReplaceInstWithInst(FI, funcInst);
		// errs() << "Thread Fences replaced\n";
	}
	return true;
}
632
633 bool CDSPass::isAtomicCall(Instruction *I) {
634         if ( auto *CI = dyn_cast<CallInst>(I) ) {
635                 Function *fun = CI->getCalledFunction();
636                 if (fun == NULL)
637                         return false;
638
639                 StringRef funName = fun->getName();
640
641                 // todo: come up with better rules for function name checking
642                 for (StringRef name : AtomicFuncNames) {
643                         if ( funName.contains(name) ) 
644                                 return true;
645                 }
646                 
647                 for (StringRef PartialName : PartialAtomicFuncNames) {
648                         if (funName.contains(PartialName) && 
649                                         funName.contains("atomic") )
650                                 return true;
651                 }
652         }
653
654         return false;
655 }
656
657 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
658         IRBuilder<> IRB(CI);
659         Function *fun = CI->getCalledFunction();
660         StringRef funName = fun->getName();
661         std::vector<Value *> parameters;
662
663         User::op_iterator begin = CI->arg_begin();
664         User::op_iterator end = CI->arg_end();
665         for (User::op_iterator it = begin; it != end; ++it) {
666                 Value *param = *it;
667                 parameters.push_back(param);
668         }
669
670         // obtain source line number of the CallInst
671         Value *position = getPosition(CI, IRB);
672
673         // the pointer to the address is always the first argument
674         Value *OrigPtr = parameters[0];
675
676         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
677         if (Idx < 0)
678                 return false;
679
680         const unsigned ByteSize = 1U << Idx;
681         const unsigned BitSize = ByteSize * 8;
682         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
683         Type *PtrTy = Ty->getPointerTo();
684
685         // atomic_init; args = {obj, order}
686         if (funName.contains("atomic_init")) {
687                 Value *OrigVal = parameters[1];
688
689                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
690                 Value *val;
691                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
692                         val = IRB.CreatePointerCast(OrigVal, Ty);
693                 else
694                         val = IRB.CreateIntCast(OrigVal, Ty, true);
695
696                 Value *args[] = {ptr, val, position};
697
698                 Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
699                 ReplaceInstWithInst(CI, funcInst);
700
701                 return true;
702         }
703
704         // atomic_load; args = {obj, order}
705         if (funName.contains("atomic_load")) {
706                 bool isExplicit = funName.contains("atomic_load_explicit");
707
708                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
709                 Value *order;
710                 if (isExplicit)
711                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
712                 else 
713                         order = ConstantInt::get(OrdTy, 
714                                                         (int) AtomicOrderingCABI::seq_cst);
715                 Value *args[] = {ptr, order, position};
716                 
717                 Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
718                 ReplaceInstWithInst(CI, funcInst);
719
720                 return true;
721         } else if (funName.contains("atomic") && 
722                                         funName.contains("load") ) {
723                 // does this version of call always have an atomic order as an argument?
724                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
725                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
726                 Value *args[] = {ptr, order, position};
727
728                 if (!CI->getType()->isPointerTy()) {
729                         return false;   
730                 } 
731
732                 CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
733                 Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
734
735                 CI->replaceAllUsesWith(RetVal);
736                 CI->eraseFromParent();
737
738                 return true;
739         }
740
741         // atomic_store; args = {obj, val, order}
742         if (funName.contains("atomic_store")) {
743                 bool isExplicit = funName.contains("atomic_store_explicit");
744                 Value *OrigVal = parameters[1];
745
746                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
747                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
748                 Value *order;
749                 if (isExplicit)
750                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
751                 else 
752                         order = ConstantInt::get(OrdTy, 
753                                                         (int) AtomicOrderingCABI::seq_cst);
754                 Value *args[] = {ptr, val, order, position};
755                 
756                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
757                 ReplaceInstWithInst(CI, funcInst);
758
759                 return true;
760         } else if (funName.contains("atomic") && 
761                                         funName.contains("store") ) {
762                 // does this version of call always have an atomic order as an argument?
763                 Value *OrigVal = parameters[1];
764
765                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
766                 Value *val;
767                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
768                         val = IRB.CreatePointerCast(OrigVal, Ty);
769                 else
770                         val = IRB.CreateIntCast(OrigVal, Ty, true);
771
772                 Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
773                 Value *args[] = {ptr, val, order, position};
774
775                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
776                 ReplaceInstWithInst(CI, funcInst);
777
778                 return true;
779         }
780
781         // atomic_fetch_*; args = {obj, val, order}
782         if (funName.contains("atomic_fetch_") || 
783                 funName.contains("atomic_exchange")) {
784
785                 /* TODO: implement stricter function name checking */
786                 if (funName.contains("non"))
787                         return false;
788
789                 bool isExplicit = funName.contains("_explicit");
790                 Value *OrigVal = parameters[1];
791
792                 int op;
793                 if ( funName.contains("_fetch_add") )
794                         op = AtomicRMWInst::Add;
795                 else if ( funName.contains("_fetch_sub") )
796                         op = AtomicRMWInst::Sub;
797                 else if ( funName.contains("_fetch_and") )
798                         op = AtomicRMWInst::And;
799                 else if ( funName.contains("_fetch_or") )
800                         op = AtomicRMWInst::Or;
801                 else if ( funName.contains("_fetch_xor") )
802                         op = AtomicRMWInst::Xor;
803                 else if ( funName.contains("atomic_exchange") )
804                         op = AtomicRMWInst::Xchg;
805                 else {
806                         errs() << "Unknown atomic read-modify-write operation\n";
807                         return false;
808                 }
809
810                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
811                 Value *val;
812                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
813                         val = IRB.CreatePointerCast(OrigVal, Ty);
814                 else
815                         val = IRB.CreateIntCast(OrigVal, Ty, true);
816
817                 Value *order;
818                 if (isExplicit)
819                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
820                 else 
821                         order = ConstantInt::get(OrdTy, 
822                                                         (int) AtomicOrderingCABI::seq_cst);
823                 Value *args[] = {ptr, val, order, position};
824                 
825                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
826                 ReplaceInstWithInst(CI, funcInst);
827
828                 return true;
829         } else if (funName.contains("fetch")) {
830                 errs() << "atomic exchange captured. Not implemented yet. ";
831                 errs() << "See source file :";
832                 getPosition(CI, IRB, true);
833         } else if (funName.contains("exchange") &&
834                         !funName.contains("compare_exchange") ) {
835                 errs() << "atomic exchange captured. Not implemented yet. ";
836                 errs() << "See source file :";
837                 getPosition(CI, IRB, true);
838         }
839
840         /* atomic_compare_exchange_*; 
841            args = {obj, expected, new value, order1, order2}
842         */
843         if ( funName.contains("atomic_compare_exchange_") ) {
844                 bool isExplicit = funName.contains("_explicit");
845
846                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
847                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
848                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
849
850                 Value *order_succ, *order_fail;
851                 if (isExplicit) {
852                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
853                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
854                 } else  {
855                         order_succ = ConstantInt::get(OrdTy, 
856                                                         (int) AtomicOrderingCABI::seq_cst);
857                         order_fail = ConstantInt::get(OrdTy, 
858                                                         (int) AtomicOrderingCABI::seq_cst);
859                 }
860
861                 Value *args[] = {Addr, CmpOperand, NewOperand, 
862                                                         order_succ, order_fail, position};
863                 
864                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
865                 ReplaceInstWithInst(CI, funcInst);
866
867                 return true;
868         } else if ( funName.contains("compare_exchange_strong") ||
869                                 funName.contains("compare_exchange_weak") ) {
870                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
871                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
872                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
873
874                 Value *order_succ, *order_fail;
875                 order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
876                 order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
877
878                 Value *args[] = {Addr, CmpOperand, NewOperand, 
879                                                         order_succ, order_fail, position};
880                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
881                 ReplaceInstWithInst(CI, funcInst);
882
883                 return true;
884         }
885
886         return false;
887 }
888
889 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
890                                                                                 const DataLayout &DL) {
891         Type *OrigPtrTy = Addr->getType();
892         Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
893         assert(OrigTy->isSized());
894         uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
895         if (TypeSize != 8  && TypeSize != 16 &&
896                 TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
897                 NumAccessesWithBadSize++;
898                 // Ignore all unusual sizes.
899                 return -1;
900         }
901         size_t Idx = countTrailingZeros(TypeSize / 8);
902         //assert(Idx < kNumberOfAccessSizes);
903         if (Idx >= kNumberOfAccessSizes) {
904                 return -1;
905         }
906         return Idx;
907 }
908
909
// Unique pass identifier; its address (not value) identifies the pass type.
char CDSPass::ID = 0;

// Automatically enable the pass.
// Callback invoked by RegisterStandardPasses to append CDSPass to the
// legacy pass pipeline; the pass manager takes ownership of the allocation.
static void registerCDSPass(const PassManagerBuilder &,
							legacy::PassManagerBase &PM) {
	PM.add(new CDSPass());
}

/* Enable the pass when opt level is greater than 0 */
// EP_OptimizerLast: run after all other optimizations have finished.
static RegisterStandardPasses 
	RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
registerCDSPass);

/* Enable the pass when opt level is 0 */
// EP_EnabledOnOptLevel0 is required because EP_OptimizerLast does not
// fire at -O0; registering both covers every optimization level.
static RegisterStandardPasses 
	RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0,
registerCDSPass);