Compare exchange strong / weak may be called without the failure memory order
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
12 //
13 // The tool is under development, for the details about previous versions see
14 // http://code.google.com/p/data-race-test
15 //
16 // The instrumentation phase is quite simple:
17 //   - Insert calls to run-time library before every memory access.
18 //      - Optimizations may apply to avoid instrumenting some of the accesses.
19 //   - Insert calls at function entry/exit.
20 // The rest is handled by the run-time library.
21 //===----------------------------------------------------------------------===//
22
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/Analysis/CaptureTracking.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/EscapeEnumerator.h"
46 #include <vector>
47
48 using namespace llvm;
49
50 #define DEBUG_TYPE "CDS"
51 #include <llvm/IR/DebugLoc.h>
52
53 Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
54 {
55         const DebugLoc & debug_location = I->getDebugLoc ();
56         std::string position_string;
57         {
58                 llvm::raw_string_ostream position_stream (position_string);
59                 debug_location . print (position_stream);
60         }
61
62         if (print) {
63                 errs() << position_string << "\n";
64         }
65
66         return IRB.CreateGlobalStringPtr (position_string);
67 }
68
69 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
70 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
71 STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
72 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
73 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
74
75 STATISTIC(NumOmittedReadsBeforeWrite,
76           "Number of reads ignored due to following writes");
77 STATISTIC(NumOmittedReadsFromConstantGlobals,
78           "Number of reads from constant globals");
79 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
80 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
81
82 Type * OrdTy;
83
84 Type * Int8PtrTy;
85 Type * Int16PtrTy;
86 Type * Int32PtrTy;
87 Type * Int64PtrTy;
88
89 Type * VoidTy;
90
91 static const size_t kNumberOfAccessSizes = 4;
92
93 int getAtomicOrderIndex(AtomicOrdering order) {
94         switch (order) {
95                 case AtomicOrdering::Monotonic: 
96                         return (int)AtomicOrderingCABI::relaxed;
97                 //  case AtomicOrdering::Consume:         // not specified yet
98                 //    return AtomicOrderingCABI::consume;
99                 case AtomicOrdering::Acquire: 
100                         return (int)AtomicOrderingCABI::acquire;
101                 case AtomicOrdering::Release: 
102                         return (int)AtomicOrderingCABI::release;
103                 case AtomicOrdering::AcquireRelease: 
104                         return (int)AtomicOrderingCABI::acq_rel;
105                 case AtomicOrdering::SequentiallyConsistent: 
106                         return (int)AtomicOrderingCABI::seq_cst;
107                 default:
108                         // unordered or Not Atomic
109                         return -1;
110         }
111 }
112
113 AtomicOrderingCABI indexToAtomicOrder(int index) {
114         switch (index) {
115                 case 0:
116                         return AtomicOrderingCABI::relaxed;
117                 case 1:
118                         return AtomicOrderingCABI::consume;
119                 case 2:
120                         return AtomicOrderingCABI::acquire;
121                 case 3:
122                         return AtomicOrderingCABI::release;
123                 case 4:
124                         return AtomicOrderingCABI::acq_rel;
125                 case 5:
126                         return AtomicOrderingCABI::seq_cst;
127                 default:
128                         errs() << "Bad Atomic index\n";
129                         return AtomicOrderingCABI::seq_cst;
130         }
131 }
132
133 /* According to atomic_base.h: __cmpexch_failure_order */
134 int AtomicCasFailureOrderIndex(int index) {
135         AtomicOrderingCABI succ_order = indexToAtomicOrder(index);
136         AtomicOrderingCABI fail_order;
137         if (succ_order == AtomicOrderingCABI::acq_rel)
138                 fail_order = AtomicOrderingCABI::acquire;
139         else if (succ_order == AtomicOrderingCABI::release) 
140                 fail_order = AtomicOrderingCABI::relaxed;
141         else
142                 fail_order = succ_order;
143
144         return (int) fail_order;
145 }
146
namespace {
	/// CDSPass: legacy FunctionPass that rewrites plain, volatile and atomic
	/// memory accesses into calls to the CDS model-checking runtime.
	/// Derived from ThreadSanitizer's instrumentation pass.
	struct CDSPass : public FunctionPass {
		static char ID;
		CDSPass() : FunctionPass(ID) {}
		bool runOnFunction(Function &F) override; 
		StringRef getPassName() const override;

	private:
		// Declare/cache the cds_* runtime callbacks in module M.
		void initializeCallbacks(Module &M);
		// Instrument a plain (non-atomic, non-volatile) load or store.
		bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
		// Replace a volatile load/store with a cds_volatile_* call.
		bool instrumentVolatile(Instruction *I, const DataLayout &DL);
		// Heuristic name check: does this call target an atomic library function?
		bool isAtomicCall(Instruction *I);
		// Replace an atomic instruction (load/store/RMW/CAS/fence) with a runtime call.
		bool instrumentAtomic(Instruction *I, const DataLayout &DL);
		// Replace a libcall-style atomic (atomic_load, atomic_fetch_add, ...) call.
		bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
		// TSan-style filtering of candidate loads/stores within one basic block.
		void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
																			SmallVectorImpl<Instruction *> &All,
																			const DataLayout &DL);
		bool addrPointsToConstantData(Value *Addr);
		// log2(access size in bytes): index into the callback tables, -1 if unsupported.
		int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);

		// Callbacks to run-time library are computed in doInitialization.
		Constant * CDSFuncEntry;
		Constant * CDSFuncExit;

		// One callback per access size (1/2/4/8 bytes), indexed by log2(size).
		Constant * CDSLoad[kNumberOfAccessSizes];
		Constant * CDSStore[kNumberOfAccessSizes];
		Constant * CDSVolatileLoad[kNumberOfAccessSizes];
		Constant * CDSVolatileStore[kNumberOfAccessSizes];
		Constant * CDSAtomicInit[kNumberOfAccessSizes];
		Constant * CDSAtomicLoad[kNumberOfAccessSizes];
		Constant * CDSAtomicStore[kNumberOfAccessSizes];
		// RMW callbacks indexed by [binop][log2(size)]; unsupported ops stay null.
		Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
		// CAS variants: _v1 returns the old value, _v2 returns a success flag.
		Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
		Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
		Constant * CDSAtomicThreadFence;

		// Function-name fragments used by isAtomicCall (populated in runOnFunction).
		std::vector<StringRef> AtomicFuncNames;
		std::vector<StringRef> PartialAtomicFuncNames;
	};
}
187
188 StringRef CDSPass::getPassName() const {
189         return "CDSPass";
190 }
191
192 static bool isVtableAccess(Instruction *I) {
193         if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
194                 return Tag->isTBAAVtableAccess();
195         return false;
196 }
197
// Declare (or look up) every cds_* runtime callback in module M and cache the
// resulting Constant*s in the member tables.  Safe to call repeatedly:
// getOrInsertFunction reuses an existing declaration with the same signature.
void CDSPass::initializeCallbacks(Module &M) {
	LLVMContext &Ctx = M.getContext();

	Type * Int1Ty = Type::getInt1Ty(Ctx);
	OrdTy = Type::getInt32Ty(Ctx);         // memory-order arguments are plain i32

	Int8PtrTy  = Type::getInt8PtrTy(Ctx);
	Int16PtrTy = Type::getInt16PtrTy(Ctx);
	Int32PtrTy = Type::getInt32PtrTy(Ctx);
	Int64PtrTy = Type::getInt64PtrTy(Ctx);

	VoidTy = Type::getVoidTy(Ctx);

	// void cds_func_entry(const char *fn_name); void cds_func_exit(const char *fn_name)
	CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
								VoidTy, Int8PtrTy);
	CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
								VoidTy, Int8PtrTy);

	// Get the functions to call from our runtime library — one set per access
	// size: i = 0..3 corresponds to 1/2/4/8-byte (8/16/32/64-bit) accesses.
	for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
		const unsigned ByteSize = 1U << i;
		const unsigned BitSize = ByteSize * 8;

		std::string ByteSizeStr = utostr(ByteSize);
		std::string BitSizeStr = utostr(BitSize);

		Type *Ty = Type::getIntNTy(Ctx, BitSize);
		Type *PtrTy = Ty->getPointerTo();

		// e.g. uint8_t cds_atomic_load8(void *obj, int atomic_index, const char *pos)
		//      void cds_atomic_store8(void *obj, uint8_t val, int atomic_index, const char *pos)
		SmallString<32> LoadName("cds_load" + BitSizeStr);
		SmallString<32> StoreName("cds_store" + BitSizeStr);
		SmallString<32> VolatileLoadName("cds_volatile_load" + BitSizeStr);
		SmallString<32> VolatileStoreName("cds_volatile_store" + BitSizeStr);
		SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
		SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
		SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);

		CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
		CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
		CDSVolatileLoad[i]  = M.getOrInsertFunction(VolatileLoadName,
									Ty, PtrTy, Int8PtrTy);
		CDSVolatileStore[i] = M.getOrInsertFunction(VolatileStoreName, 
									VoidTy, PtrTy, Ty, Int8PtrTy);
		CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
								VoidTy, PtrTy, Ty, Int8PtrTy);
		CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
								Ty, PtrTy, OrdTy, Int8PtrTy);
		CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
								VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);

		// Declare the RMW callbacks for the operations CDS models; other
		// binops keep a null entry in the table.
		for (int op = AtomicRMWInst::FIRST_BINOP; 
			op <= AtomicRMWInst::LAST_BINOP; ++op) {
			CDSAtomicRMW[op][i] = nullptr;
			std::string NamePart;

			if (op == AtomicRMWInst::Xchg)
				NamePart = "_exchange";
			else if (op == AtomicRMWInst::Add) 
				NamePart = "_fetch_add";
			else if (op == AtomicRMWInst::Sub)
				NamePart = "_fetch_sub";
			else if (op == AtomicRMWInst::And)
				NamePart = "_fetch_and";
			else if (op == AtomicRMWInst::Or)
				NamePart = "_fetch_or";
			else if (op == AtomicRMWInst::Xor)
				NamePart = "_fetch_xor";
			else
				continue;

			SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
			CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
												Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
		}

		// Only the strong CAS is supported.  _v1 returns the old value;
		// _v2 returns an i1 success flag and takes the expected value by pointer.
		SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
		SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
		CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
								Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
		CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
								Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
	}

	CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
																				VoidTy, OrdTy, Int8PtrTy);
}
287
288 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
289         // Peel off GEPs and BitCasts.
290         Addr = Addr->stripInBoundsOffsets();
291
292         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
293                 if (GV->hasSection()) {
294                         StringRef SectionName = GV->getSection();
295                         // Check if the global is in the PGO counters section.
296                         auto OF = Triple(M->getTargetTriple()).getObjectFormat();
297                         if (SectionName.endswith(
298                               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
299                                 return false;
300                 }
301
302                 // Check if the global is private gcov data.
303                 if (GV->getName().startswith("__llvm_gcov") ||
304                 GV->getName().startswith("__llvm_gcda"))
305                 return false;
306         }
307
308         // Do not instrument acesses from different address spaces; we cannot deal
309         // with them.
310         if (Addr) {
311                 Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
312                 if (PtrTy->getPointerAddressSpace() != 0)
313                         return false;
314         }
315
316         return true;
317 }
318
319 bool CDSPass::addrPointsToConstantData(Value *Addr) {
320         // If this is a GEP, just analyze its pointer operand.
321         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
322                 Addr = GEP->getPointerOperand();
323
324         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
325                 if (GV->isConstant()) {
326                         // Reads from constant globals can not race with any writes.
327                         NumOmittedReadsFromConstantGlobals++;
328                         return true;
329                 }
330         } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
331                 if (isVtableAccess(L)) {
332                         // Reads from a vtable pointer can not race with any writes.
333                         NumOmittedReadsFromVtable++;
334                         return true;
335                 }
336         }
337         return false;
338 }
339
340 bool CDSPass::runOnFunction(Function &F) {
341         if (F.getName() == "main") {
342                 F.setName("user_main");
343                 errs() << "main replaced by user_main\n";
344         }
345
346         if (true) {
347                 initializeCallbacks( *F.getParent() );
348
349                 AtomicFuncNames = 
350                 {
351                         "atomic_init", "atomic_load", "atomic_store", 
352                         "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
353                 };
354
355                 PartialAtomicFuncNames = 
356                 { 
357                         "load", "store", "fetch", "exchange", "compare_exchange_" 
358                 };
359
360                 SmallVector<Instruction*, 8> AllLoadsAndStores;
361                 SmallVector<Instruction*, 8> LocalLoadsAndStores;
362                 SmallVector<Instruction*, 8> VolatileLoadsAndStores;
363                 SmallVector<Instruction*, 8> AtomicAccesses;
364
365                 std::vector<Instruction *> worklist;
366
367                 bool Res = false;
368                 bool HasAtomic = false;
369                 bool HasVolatile = false;
370                 const DataLayout &DL = F.getParent()->getDataLayout();
371
372                 // errs() << "--- " << F.getName() << "---\n";
373
374                 for (auto &B : F) {
375                         for (auto &I : B) {
376                                 if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
377                                         AtomicAccesses.push_back(&I);
378                                         HasAtomic = true;
379                                 } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
380                                         LoadInst *LI = dyn_cast<LoadInst>(&I);
381                                         StoreInst *SI = dyn_cast<StoreInst>(&I);
382                                         bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
383
384                                         if (isVolatile) {
385                                                 VolatileLoadsAndStores.push_back(&I);
386                                                 HasVolatile = true;
387                                         } else
388                                                 LocalLoadsAndStores.push_back(&I);
389                                 } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
390                                         // not implemented yet
391                                 }
392                         }
393
394                         chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
395                 }
396
397                 for (auto Inst : AllLoadsAndStores) {
398                         Res |= instrumentLoadOrStore(Inst, DL);
399                 }
400
401                 for (auto Inst : VolatileLoadsAndStores) {
402                         Res |= instrumentVolatile(Inst, DL);
403                 }
404
405                 for (auto Inst : AtomicAccesses) {
406                         Res |= instrumentAtomic(Inst, DL);
407                 }
408
409                 // only instrument functions that contain atomics
410                 if (Res && ( HasAtomic || HasVolatile) ) {
411                         IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
412                         /* Unused for now
413                         Value *ReturnAddress = IRB.CreateCall(
414                                 Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
415                                 IRB.getInt32(0));
416                         */
417
418                         Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
419                         IRB.CreateCall(CDSFuncEntry, FuncName);
420
421                         EscapeEnumerator EE(F, "cds_cleanup", true);
422                         while (IRBuilder<> *AtExit = EE.Next()) {
423                           AtExit->CreateCall(CDSFuncExit, FuncName);
424                         }
425
426                         Res = true;
427                 }
428         }
429
430         return false;
431 }
432
433 void CDSPass::chooseInstructionsToInstrument(
434         SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
435         const DataLayout &DL) {
436         SmallPtrSet<Value*, 8> WriteTargets;
437         // Iterate from the end.
438         for (Instruction *I : reverse(Local)) {
439                 if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
440                         Value *Addr = Store->getPointerOperand();
441                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
442                                 continue;
443                         WriteTargets.insert(Addr);
444                 } else {
445                         LoadInst *Load = cast<LoadInst>(I);
446                         Value *Addr = Load->getPointerOperand();
447                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
448                                 continue;
449                         if (WriteTargets.count(Addr)) {
450                                 // We will write to this temp, so no reason to analyze the read.
451                                 NumOmittedReadsBeforeWrite++;
452                                 continue;
453                         }
454                         if (addrPointsToConstantData(Addr)) {
455                                 // Addr points to some constant data -- it can not race with any writes.
456                                 continue;
457                         }
458                 }
459                 Value *Addr = isa<StoreInst>(*I)
460                         ? cast<StoreInst>(I)->getPointerOperand()
461                         : cast<LoadInst>(I)->getPointerOperand();
462                 if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
463                                 !PointerMayBeCaptured(Addr, true, true)) {
464                         // The variable is addressable but not captured, so it cannot be
465                         // referenced from a different thread and participate in a data race
466                         // (see llvm/Analysis/CaptureTracking.h for details).
467                         NumOmittedNonCaptured++;
468                         continue;
469                 }
470                 All.push_back(I);
471         }
472         Local.clear();
473 }
474
475
// Instrument a plain (non-atomic, non-volatile) load or store by inserting a
// cds_load<N>/cds_store<N> call before it.  The original instruction is kept;
// only a shadow call is added.  Returns true when a call was inserted.
bool CDSPass::instrumentLoadOrStore(Instruction *I,
									const DataLayout &DL) {
	IRBuilder<> IRB(I);
	bool IsWrite = isa<StoreInst>(*I);
	Value *Addr = IsWrite
		? cast<StoreInst>(I)->getPointerOperand()
		: cast<LoadInst>(I)->getPointerOperand();

	// swifterror memory addresses are mem2reg promoted by instruction selection.
	// As such they cannot have regular uses like an instrumentation function and
	// it makes no sense to track them as memory.
	if (Addr->isSwiftError())
	return false;

	// Index selects the 8/16/32/64-bit callback; negative means unsupported size.
	int Idx = getMemoryAccessFuncIndex(Addr, DL);
	if (Idx < 0)
		return false;

//  not supported by CDS yet
/*  if (IsWrite && isVtableAccess(I)) {
    LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
    Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
    // StoredValue may be a vector type if we are storing several vptrs at once.
    // In this case, just take the first element of the vector since this is
    // enough to find vptr races.
    if (isa<VectorType>(StoredValue->getType()))
      StoredValue = IRB.CreateExtractElement(
          StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
    if (StoredValue->getType()->isIntegerTy())
      StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
    // Call TsanVptrUpdate.
    IRB.CreateCall(TsanVptrUpdate,
                   {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
                    IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
    NumInstrumentedVtableWrites++;
    return true;
  }

  if (!IsWrite && isVtableAccess(I)) {
    IRB.CreateCall(TsanVptrLoad,
                   IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
    NumInstrumentedVtableReads++;
    return true;
  }
*/

	Value *OnAccessFunc = nullptr;
	OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];

	// NOTE(review): this casts Addr to its *own* type, so it is a no-op and
	// ArgType is simply Addr's type.  Presumably a cast to the Int<N>PtrTy
	// matching Idx was intended -- verify before changing, since the check
	// below deliberately bails out for any other pointer type.
	Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();

	if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
			ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
		// if other types of load or stores are passed in
		return false;	
	}
	IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
	if (IsWrite) NumInstrumentedWrites++;
	else         NumInstrumentedReads++;
	return true;
}
537
538 bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
539         IRBuilder<> IRB(I);
540         Value *position = getPosition(I, IRB);
541
542         if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
543                 assert( LI->isVolatile() );
544                 Value *Addr = LI->getPointerOperand();
545                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
546                 if (Idx < 0)
547                         return false;
548
549                 Value *args[] = {Addr, position};
550                 Instruction* funcInst=CallInst::Create(CDSVolatileLoad[Idx], args);
551                 ReplaceInstWithInst(LI, funcInst);
552         } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
553                 assert( SI->isVolatile() );
554                 Value *Addr = SI->getPointerOperand();
555                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
556                 if (Idx < 0)
557                         return false;
558
559                 Value *val = SI->getValueOperand();
560                 Value *args[] = {Addr, val, position};
561                 Instruction* funcInst=CallInst::Create(CDSVolatileStore[Idx], args);
562                 ReplaceInstWithInst(SI, funcInst);
563         } else {
564                 return false;
565         }
566
567         return true;
568 }
569
// Rewrite an atomic instruction into the matching cds_* runtime call.
// Handles atomic load/store, RMW, cmpxchg and fences; atomic library calls
// are delegated to instrumentAtomicCall.  Returns true on replacement (note
// the trailing `return true` also covers unrecognized instruction kinds).
bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
	IRBuilder<> IRB(I);

	// C/C++ atomic library calls (atomic_load(...), etc.) take a separate path.
	if (auto *CI = dyn_cast<CallInst>(I)) {
		return instrumentAtomicCall(CI, DL);
	}

	// Source-position string passed to every runtime callback.
	Value *position = getPosition(I, IRB);

	if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
		Value *Addr = LI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, order, position};
		Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
		ReplaceInstWithInst(LI, funcInst);
	} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
		Value *Addr = SI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
		Value *val = SI->getValueOperand();
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, val, order, position};
		Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
		ReplaceInstWithInst(SI, funcInst);
	} else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
		Value *Addr = RMWI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
		Value *val = RMWI->getValOperand();
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, val, order, position};
		// NOTE: CDSAtomicRMW entries for unsupported binops are nullptr;
		// such RMW operations would crash here if they ever reach this point.
		Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
		ReplaceInstWithInst(RMWI, funcInst);
	} else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
		IRBuilder<> IRB(CASI);

		Value *Addr = CASI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		const unsigned ByteSize = 1U << Idx;
		const unsigned BitSize = ByteSize * 8;
		Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
		Type *PtrTy = Ty->getPointerTo();

		// Pointer operands are punned to the same-width integer type the
		// runtime signature expects.
		Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
		Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);

		// Both success and failure orderings are forwarded explicitly.
		int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
		int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
		Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
		Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);

		Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
						 CmpOperand, NewOperand,
						 order_succ, order_fail, position};

		// _v1 returns the old value; success is recomputed by comparing it
		// against the expected operand.
		CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[Idx], Args);
		Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);

		Value *OldVal = funcInst;
		Type *OrigOldValTy = CASI->getNewValOperand()->getType();
		if (Ty != OrigOldValTy) {
			// The value is a pointer, so we need to cast the return value.
			OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
		}

		// Rebuild the { oldval, success } aggregate that cmpxchg produces.
		Value *Res =
		  IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
		Res = IRB.CreateInsertValue(Res, Success, 1);

		I->replaceAllUsesWith(Res);
		I->eraseFromParent();
	} else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
		int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *Args[] = {order, position};

		CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
		ReplaceInstWithInst(FI, funcInst);
		// errs() << "Thread Fences replaced\n";
	}
	return true;
}
666
667 bool CDSPass::isAtomicCall(Instruction *I) {
668         if ( auto *CI = dyn_cast<CallInst>(I) ) {
669                 Function *fun = CI->getCalledFunction();
670                 if (fun == NULL)
671                         return false;
672
673                 StringRef funName = fun->getName();
674
675                 // todo: come up with better rules for function name checking
676                 for (StringRef name : AtomicFuncNames) {
677                         if ( funName.contains(name) ) 
678                                 return true;
679                 }
680                 
681                 for (StringRef PartialName : PartialAtomicFuncNames) {
682                         if (funName.contains(PartialName) && 
683                                         funName.contains("atomic") )
684                                 return true;
685                 }
686         }
687
688         return false;
689 }
690
691 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
692         IRBuilder<> IRB(CI);
693         Function *fun = CI->getCalledFunction();
694         StringRef funName = fun->getName();
695         std::vector<Value *> parameters;
696
697         User::op_iterator begin = CI->arg_begin();
698         User::op_iterator end = CI->arg_end();
699         for (User::op_iterator it = begin; it != end; ++it) {
700                 Value *param = *it;
701                 parameters.push_back(param);
702         }
703
704         // obtain source line number of the CallInst
705         Value *position = getPosition(CI, IRB);
706
707         // the pointer to the address is always the first argument
708         Value *OrigPtr = parameters[0];
709
710         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
711         if (Idx < 0)
712                 return false;
713
714         const unsigned ByteSize = 1U << Idx;
715         const unsigned BitSize = ByteSize * 8;
716         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
717         Type *PtrTy = Ty->getPointerTo();
718
719         // atomic_init; args = {obj, order}
720         if (funName.contains("atomic_init")) {
721                 Value *OrigVal = parameters[1];
722
723                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
724                 Value *val;
725                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
726                         val = IRB.CreatePointerCast(OrigVal, Ty);
727                 else
728                         val = IRB.CreateIntCast(OrigVal, Ty, true);
729
730                 Value *args[] = {ptr, val, position};
731
732                 Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
733                 ReplaceInstWithInst(CI, funcInst);
734
735                 return true;
736         }
737
738         // atomic_load; args = {obj, order}
739         if (funName.contains("atomic_load")) {
740                 bool isExplicit = funName.contains("atomic_load_explicit");
741
742                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
743                 Value *order;
744                 if (isExplicit)
745                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
746                 else 
747                         order = ConstantInt::get(OrdTy, 
748                                                         (int) AtomicOrderingCABI::seq_cst);
749                 Value *args[] = {ptr, order, position};
750                 
751                 Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
752                 ReplaceInstWithInst(CI, funcInst);
753
754                 return true;
755         } else if (funName.contains("atomic") && 
756                                         funName.contains("load") ) {
757                 // does this version of call always have an atomic order as an argument?
758                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
759                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
760                 Value *args[] = {ptr, order, position};
761
762                 if (!CI->getType()->isPointerTy()) {
763                         return false;   
764                 } 
765
766                 CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
767                 Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
768
769                 CI->replaceAllUsesWith(RetVal);
770                 CI->eraseFromParent();
771
772                 return true;
773         }
774
775         // atomic_store; args = {obj, val, order}
776         if (funName.contains("atomic_store")) {
777                 bool isExplicit = funName.contains("atomic_store_explicit");
778                 Value *OrigVal = parameters[1];
779
780                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
781                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
782                 Value *order;
783                 if (isExplicit)
784                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
785                 else 
786                         order = ConstantInt::get(OrdTy, 
787                                                         (int) AtomicOrderingCABI::seq_cst);
788                 Value *args[] = {ptr, val, order, position};
789                 
790                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
791                 ReplaceInstWithInst(CI, funcInst);
792
793                 return true;
794         } else if (funName.contains("atomic") && 
795                                         funName.contains("store") ) {
796                 // does this version of call always have an atomic order as an argument?
797                 Value *OrigVal = parameters[1];
798
799                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
800                 Value *val;
801                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
802                         val = IRB.CreatePointerCast(OrigVal, Ty);
803                 else
804                         val = IRB.CreateIntCast(OrigVal, Ty, true);
805
806                 Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
807                 Value *args[] = {ptr, val, order, position};
808
809                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
810                 ReplaceInstWithInst(CI, funcInst);
811
812                 return true;
813         }
814
815         // atomic_fetch_*; args = {obj, val, order}
816         if (funName.contains("atomic_fetch_") || 
817                 funName.contains("atomic_exchange")) {
818
819                 /* TODO: implement stricter function name checking */
820                 if (funName.contains("non"))
821                         return false;
822
823                 bool isExplicit = funName.contains("_explicit");
824                 Value *OrigVal = parameters[1];
825
826                 int op;
827                 if ( funName.contains("_fetch_add") )
828                         op = AtomicRMWInst::Add;
829                 else if ( funName.contains("_fetch_sub") )
830                         op = AtomicRMWInst::Sub;
831                 else if ( funName.contains("_fetch_and") )
832                         op = AtomicRMWInst::And;
833                 else if ( funName.contains("_fetch_or") )
834                         op = AtomicRMWInst::Or;
835                 else if ( funName.contains("_fetch_xor") )
836                         op = AtomicRMWInst::Xor;
837                 else if ( funName.contains("atomic_exchange") )
838                         op = AtomicRMWInst::Xchg;
839                 else {
840                         errs() << "Unknown atomic read-modify-write operation\n";
841                         return false;
842                 }
843
844                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
845                 Value *val;
846                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
847                         val = IRB.CreatePointerCast(OrigVal, Ty);
848                 else
849                         val = IRB.CreateIntCast(OrigVal, Ty, true);
850
851                 Value *order;
852                 if (isExplicit)
853                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
854                 else 
855                         order = ConstantInt::get(OrdTy, 
856                                                         (int) AtomicOrderingCABI::seq_cst);
857                 Value *args[] = {ptr, val, order, position};
858                 
859                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
860                 ReplaceInstWithInst(CI, funcInst);
861
862                 return true;
863         } else if (funName.contains("fetch")) {
864                 errs() << "atomic fetch captured. Not implemented yet. ";
865                 errs() << "See source file :";
866                 getPosition(CI, IRB, true);
867                 return false;
868         } else if (funName.contains("exchange") &&
869                         !funName.contains("compare_exchange") ) {
870                 if (CI->getType()->isPointerTy()) {
871                         // Can not deal with this now
872                         errs() << "atomic exchange captured. Not implemented yet. ";
873                         errs() << "See source file :";
874                         getPosition(CI, IRB, true);
875
876                         return false;
877                 }
878
879                 Value *OrigVal = parameters[1];
880
881                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
882                 Value *val;
883                 if (OrigVal->getType()->isPtrOrPtrVectorTy())
884                         val = IRB.CreatePointerCast(OrigVal, Ty);
885                 else
886                         val = IRB.CreateIntCast(OrigVal, Ty, true);
887
888                 Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
889                 Value *args[] = {ptr, val, order, position};
890                 int op = AtomicRMWInst::Xchg;
891                 
892                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
893                 ReplaceInstWithInst(CI, funcInst);
894         }
895
896         /* atomic_compare_exchange_*; 
897            args = {obj, expected, new value, order1, order2}
898         */
899         if ( funName.contains("atomic_compare_exchange_") ) {
900                 bool isExplicit = funName.contains("_explicit");
901
902                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
903                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
904                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
905
906                 Value *order_succ, *order_fail;
907                 if (isExplicit) {
908                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
909
910                         if (parameters.size() > 4) {
911                                 order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
912                         } else {
913                                 /* The failure order is not provided */
914                                 order_fail = order_succ;
915                                 ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
916                                 int index = order_succ_cast->getSExtValue();
917
918                                 order_fail = ConstantInt::get(OrdTy,
919                                                                 AtomicCasFailureOrderIndex(index));
920                         }
921                 } else  {
922                         order_succ = ConstantInt::get(OrdTy, 
923                                                         (int) AtomicOrderingCABI::seq_cst);
924                         order_fail = ConstantInt::get(OrdTy, 
925                                                         (int) AtomicOrderingCABI::seq_cst);
926                 }
927
928                 Value *args[] = {Addr, CmpOperand, NewOperand, 
929                                                         order_succ, order_fail, position};
930                 
931                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
932                 ReplaceInstWithInst(CI, funcInst);
933
934                 return true;
935         } else if ( funName.contains("compare_exchange_strong") ||
936                                 funName.contains("compare_exchange_weak") ) {
937                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
938                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
939                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
940
941                 Value *order_succ, *order_fail;
942                 order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
943
944                 if (parameters.size() > 4) {
945                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
946                 } else {
947                         /* The failure order is not provided */
948                         order_fail = order_succ;
949                         ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
950                         int index = order_succ_cast->getSExtValue();
951
952                         order_fail = ConstantInt::get(OrdTy,
953                                                         AtomicCasFailureOrderIndex(index));
954                 }
955
956                 Value *args[] = {Addr, CmpOperand, NewOperand, 
957                                                         order_succ, order_fail, position};
958                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
959                 ReplaceInstWithInst(CI, funcInst);
960
961                 return true;
962         }
963
964         return false;
965 }
966
967 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
968                                                                                 const DataLayout &DL) {
969         Type *OrigPtrTy = Addr->getType();
970         Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
971         assert(OrigTy->isSized());
972         uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
973         if (TypeSize != 8  && TypeSize != 16 &&
974                 TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
975                 NumAccessesWithBadSize++;
976                 // Ignore all unusual sizes.
977                 return -1;
978         }
979         size_t Idx = countTrailingZeros(TypeSize / 8);
980         //assert(Idx < kNumberOfAccessSizes);
981         if (Idx >= kNumberOfAccessSizes) {
982                 return -1;
983         }
984         return Idx;
985 }
986
987
// Unique pass-identification token: LLVM identifies the pass by the
// ADDRESS of this variable, not its value.
char CDSPass::ID = 0;

// Automatically enable the pass.
// Callback handed to RegisterStandardPasses below; appends a fresh
// CDSPass instance to the legacy pipeline being assembled.
static void registerCDSPass(const PassManagerBuilder &,
							legacy::PassManagerBase &PM) {
	PM.add(new CDSPass());
}

/* Enable the pass when opt level is greater than 0 */
// EP_OptimizerLast: run after all other optimizations so the pass sees
// the final shape of the atomics.
static RegisterStandardPasses 
	RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
registerCDSPass);

/* Enable the pass when opt level is 0 */
// EP_EnabledOnOptLevel0 is a separate extension point; without it the
// pass would not run on -O0 builds.
static RegisterStandardPasses 
	RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0,
registerCDSPass);