recent changes
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
11 //
12 // The tool is under development, for the details about previous versions see
13 // http://code.google.com/p/data-race-test
14 //
15 // The instrumentation phase is quite simple:
16 //   - Insert calls to run-time library before every memory access.
17 //      - Optimizations may apply to avoid instrumenting some of the accesses.
18 //   - Insert calls at function entry/exit.
19 // The rest is handled by the run-time library.
20 //===----------------------------------------------------------------------===//
21
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/CaptureTracking.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/CFG.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include <list>
46 #include <vector>
47 // #include "llvm/Support/MathExtras.h"
48
49 #define DEBUG_TYPE "CDS"
50 using namespace llvm;
51
52 #define FUNCARRAYSIZE 4
53
54 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
55 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
56 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
57 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
58
59 STATISTIC(NumOmittedReadsBeforeWrite,
60           "Number of reads ignored due to following writes");
61 STATISTIC(NumOmittedReadsFromConstantGlobals,
62           "Number of reads from constant globals");
63 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
64 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
65
66 Type * Int8Ty;
67 Type * Int16Ty;
68 Type * Int32Ty;
69 Type * Int64Ty;
70 Type * OrdTy;
71
72 Type * Int8PtrTy;
73 Type * Int16PtrTy;
74 Type * Int32PtrTy;
75 Type * Int64PtrTy;
76
77 Type * VoidTy;
78
79 Constant * CDSLoad[FUNCARRAYSIZE];
80 Constant * CDSStore[FUNCARRAYSIZE];
81 Constant * CDSAtomicLoad[FUNCARRAYSIZE];
82 Constant * CDSAtomicStore[FUNCARRAYSIZE];
83 Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][FUNCARRAYSIZE];
84 Constant * CDSAtomicCAS[FUNCARRAYSIZE];
85 Constant * CDSAtomicThreadFence;
86
87 bool start = false;
88
89 int getAtomicOrderIndex(AtomicOrdering order){
90   switch (order) {
91     case AtomicOrdering::Monotonic: 
92       return (int)AtomicOrderingCABI::relaxed;
93 //  case AtomicOrdering::Consume:         // not specified yet
94 //    return AtomicOrderingCABI::consume;
95     case AtomicOrdering::Acquire: 
96       return (int)AtomicOrderingCABI::acquire;
97     case AtomicOrdering::Release: 
98       return (int)AtomicOrderingCABI::release;
99     case AtomicOrdering::AcquireRelease: 
100       return (int)AtomicOrderingCABI::acq_rel;
101     case AtomicOrdering::SequentiallyConsistent: 
102       return (int)AtomicOrderingCABI::seq_cst;
103     default:
104       // unordered or Not Atomic
105       return -1;
106   }
107 }
108
109 int getTypeSize(Type* type) {
110   if (type==Int32PtrTy) {
111     return sizeof(int)*8;
112   } else if (type==Int8PtrTy) {
113     return sizeof(char)*8;
114   } else if (type==Int16PtrTy) {
115     return sizeof(short)*8;
116   } else if (type==Int64PtrTy) {
117     return sizeof(long long int)*8;
118   } else {
119     return sizeof(void*)*8;
120   }
121
122   return -1;
123 }
124
125 static int sizetoindex(int size) {
126   switch(size) {
127     case 8:     return 0;
128     case 16:    return 1;
129     case 32:    return 2;
130     case 64:    return 3;
131   }
132   return -1;
133 }
134
135 namespace {
136   struct CDSPass : public FunctionPass {
137     static char ID;
138     CDSPass() : FunctionPass(ID) {}
139     bool runOnFunction(Function &F) override; 
140
141   private:
142     void initializeCallbacks(Module &M);
143     bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
144     bool instrumentAtomic(Instruction *I);
145     void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
146                                       SmallVectorImpl<Instruction *> &All,
147                                       const DataLayout &DL);
148     bool addrPointsToConstantData(Value *Addr);
149   };
150 }
151
152 void CDSPass::initializeCallbacks(Module &M) {
153   LLVMContext &Ctx = M.getContext();
154
155   Int8Ty  = Type::getInt8Ty(Ctx);
156   Int16Ty = Type::getInt16Ty(Ctx);
157   Int32Ty = Type::getInt32Ty(Ctx);
158   Int64Ty = Type::getInt64Ty(Ctx);
159   OrdTy = Type::getInt32Ty(Ctx);
160
161   Int8PtrTy  = Type::getInt8PtrTy(Ctx);
162   Int16PtrTy = Type::getInt16PtrTy(Ctx);
163   Int32PtrTy = Type::getInt32PtrTy(Ctx);
164   Int64PtrTy = Type::getInt64PtrTy(Ctx);
165
166   VoidTy = Type::getVoidTy(Ctx);
167   
168
169   // Get the function to call from our untime library.
170   for (unsigned i = 0; i < FUNCARRAYSIZE; i++) {
171     const unsigned ByteSize = 1U << i;
172     const unsigned BitSize = ByteSize * 8;
173 //    errs() << BitSize << "\n";
174     std::string ByteSizeStr = utostr(ByteSize);
175     std::string BitSizeStr = utostr(BitSize);
176
177     Type *Ty = Type::getIntNTy(Ctx, BitSize);
178     Type *PtrTy = Ty->getPointerTo();
179
180     // uint8_t cds_atomic_load8 (void * obj, int atomic_index)
181     // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
182     SmallString<32> LoadName("cds_load" + BitSizeStr);
183     SmallString<32> StoreName("cds_store" + BitSizeStr);
184     SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
185     SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
186
187 //    CDSLoad[i]  = M.getOrInsertFunction(LoadName, Ty, PtrTy);
188 //    CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy, Ty);
189     CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
190     CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
191     CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, Ty, PtrTy, OrdTy);
192     CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, VoidTy, PtrTy, OrdTy, Ty);
193
194     for (int op = AtomicRMWInst::FIRST_BINOP; op <= AtomicRMWInst::LAST_BINOP; ++op) {
195       CDSAtomicRMW[op][i] = nullptr;
196       std::string NamePart;
197
198       if (op == AtomicRMWInst::Xchg)
199         NamePart = "_exchange";
200       else if (op == AtomicRMWInst::Add) 
201         NamePart = "_fetch_add";
202       else if (op == AtomicRMWInst::Sub)
203         NamePart = "_fetch_sub";
204       else if (op == AtomicRMWInst::And)
205         NamePart = "_fetch_and";
206       else if (op == AtomicRMWInst::Or)
207         NamePart = "_fetch_or";
208       else if (op == AtomicRMWInst::Xor)
209         NamePart = "_fetch_xor";
210       else
211         continue;
212
213       SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
214       CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, Ty, PtrTy, OrdTy, Ty);
215     }
216
217     // only supportes strong version
218     SmallString<32> AtomicCASName("cds_atomic_compare_exchange" + BitSizeStr);    
219     CDSAtomicCAS[i]   = M.getOrInsertFunction(AtomicCASName, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy);
220   }
221
222   CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", VoidTy, OrdTy);
223 }
224
225 static bool isVtableAccess(Instruction *I) {
226   if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
227     return Tag->isTBAAVtableAccess();
228   return false;
229 }
230
231 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
232   // Peel off GEPs and BitCasts.
233   Addr = Addr->stripInBoundsOffsets();
234
235   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
236     if (GV->hasSection()) {
237       StringRef SectionName = GV->getSection();
238       // Check if the global is in the PGO counters section.
239       auto OF = Triple(M->getTargetTriple()).getObjectFormat();
240       if (SectionName.endswith(
241               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
242         return false;
243     }
244
245     // Check if the global is private gcov data.
246     if (GV->getName().startswith("__llvm_gcov") ||
247         GV->getName().startswith("__llvm_gcda"))
248       return false;
249   }
250
251   // Do not instrument acesses from different address spaces; we cannot deal
252   // with them.
253   if (Addr) {
254     Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
255     if (PtrTy->getPointerAddressSpace() != 0)
256       return false;
257   }
258
259   return true;
260 }
261
262 bool CDSPass::addrPointsToConstantData(Value *Addr) {
263   // If this is a GEP, just analyze its pointer operand.
264   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
265     Addr = GEP->getPointerOperand();
266
267   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
268     if (GV->isConstant()) {
269       // Reads from constant globals can not race with any writes.
270       NumOmittedReadsFromConstantGlobals++;
271       return true;
272     }
273   } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
274     if (isVtableAccess(L)) {
275       // Reads from a vtable pointer can not race with any writes.
276       NumOmittedReadsFromVtable++;
277       return true;
278     }
279   }
280   return false;
281 }
282
283 bool CDSPass::runOnFunction(Function &F) {
284   if (F.getName() == "main") {
285     F.setName("user_main");
286     errs() << "main replaced by user_main\n";
287
288     initializeCallbacks( *F.getParent() );
289
290     SmallVector<Instruction*, 8> AllLoadsAndStores;
291     SmallVector<Instruction*, 8> LocalLoadsAndStores;
292     SmallVector<Instruction*, 8> AtomicAccesses;
293
294     std::vector<Instruction *> worklist;
295
296     bool Res = false;
297     const DataLayout &DL = F.getParent()->getDataLayout();
298   
299     errs() << "Before\n";
300     F.dump();
301
302     for (auto &B : F) {
303       for (auto &I : B) {
304         if ( (&I)->isAtomic() ) {
305           AtomicAccesses.push_back(&I);
306         } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
307           LocalLoadsAndStores.push_back(&I);
308         } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
309           // not implemented yet
310         }
311       }
312       chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
313     }
314
315     for (auto Inst : AllLoadsAndStores) {
316 //      Res |= instrumentLoadOrStore(Inst, DL);
317 //      errs() << "load and store are not replaced\n";
318     }
319
320     for (auto Inst : AtomicAccesses) {
321       Res |= instrumentAtomic(Inst);
322     } 
323
324     if (Res) {
325       errs() << F.getName(); 
326       errs() << " has above instructions replaced\n";
327     }
328   }
329 //        errs() << "After\n";
330 //        F.dump();
331   
332   return false;
333 }
334
335 void CDSPass::chooseInstructionsToInstrument(
336     SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
337     const DataLayout &DL) {
338   SmallPtrSet<Value*, 8> WriteTargets;
339   // Iterate from the end.
340   for (Instruction *I : reverse(Local)) {
341     if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
342       Value *Addr = Store->getPointerOperand();
343       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
344         continue;
345       WriteTargets.insert(Addr);
346     } else {
347       LoadInst *Load = cast<LoadInst>(I);
348       Value *Addr = Load->getPointerOperand();
349       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
350         continue;
351       if (WriteTargets.count(Addr)) {
352         // We will write to this temp, so no reason to analyze the read.
353         NumOmittedReadsBeforeWrite++;
354         continue;
355       }
356       if (addrPointsToConstantData(Addr)) {
357         // Addr points to some constant data -- it can not race with any writes.
358         continue;
359       }
360     }
361     Value *Addr = isa<StoreInst>(*I)
362         ? cast<StoreInst>(I)->getPointerOperand()
363         : cast<LoadInst>(I)->getPointerOperand();
364     if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
365         !PointerMayBeCaptured(Addr, true, true)) {
366       // The variable is addressable but not captured, so it cannot be
367       // referenced from a different thread and participate in a data race
368       // (see llvm/Analysis/CaptureTracking.h for details).
369       NumOmittedNonCaptured++;
370       continue;
371     }
372     All.push_back(I);
373   }
374   Local.clear();
375 }
376
377
378 bool CDSPass::instrumentLoadOrStore(Instruction *I,
379                                             const DataLayout &DL) {
380   IRBuilder<> IRB(I);
381   bool IsWrite = isa<StoreInst>(*I);
382   Value *Addr = IsWrite
383       ? cast<StoreInst>(I)->getPointerOperand()
384       : cast<LoadInst>(I)->getPointerOperand();
385
386   // swifterror memory addresses are mem2reg promoted by instruction selection.
387   // As such they cannot have regular uses like an instrumentation function and
388   // it makes no sense to track them as memory.
389   if (Addr->isSwiftError())
390     return false;
391
392   int size = getTypeSize(Addr->getType());
393   int index = sizetoindex(size);
394
395 //  not supported by CDS yet
396 /*  if (IsWrite && isVtableAccess(I)) {
397     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
398     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
399     // StoredValue may be a vector type if we are storing several vptrs at once.
400     // In this case, just take the first element of the vector since this is
401     // enough to find vptr races.
402     if (isa<VectorType>(StoredValue->getType()))
403       StoredValue = IRB.CreateExtractElement(
404           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
405     if (StoredValue->getType()->isIntegerTy())
406       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
407     // Call TsanVptrUpdate.
408     IRB.CreateCall(TsanVptrUpdate,
409                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
410                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
411     NumInstrumentedVtableWrites++;
412     return true;
413   }
414
415   if (!IsWrite && isVtableAccess(I)) {
416     IRB.CreateCall(TsanVptrLoad,
417                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
418     NumInstrumentedVtableReads++;
419     return true;
420   }
421 */
422
423   Value *OnAccessFunc = nullptr;
424   OnAccessFunc = IsWrite ? CDSStore[index] : CDSLoad[index];
425   
426   Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
427
428   if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
429                 ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
430         //errs() << "A load or store of type ";
431         //errs() << *ArgType;
432         //errs() << " is passed in\n";
433         return false;   // if other types of load or stores are passed in
434   }
435   IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
436   if (IsWrite) NumInstrumentedWrites++;
437   else         NumInstrumentedReads++;
438   return true;
439 }
440
441
442 bool CDSPass::instrumentAtomic(Instruction * I) {
443   IRBuilder<> IRB(I);
444   // LLVMContext &Ctx = IRB.getContext();
445
446   if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
447     int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
448
449     Value *val = SI->getValueOperand();
450     Value *ptr = SI->getPointerOperand();
451     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
452     Value *args[] = {ptr, order, val};
453
454     int size=getTypeSize(ptr->getType());
455     int index=sizetoindex(size);
456
457     Instruction* funcInst=CallInst::Create(CDSAtomicStore[index], args,"");
458     ReplaceInstWithInst(SI, funcInst);
459     errs() << "Store replaced\n";
460   } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
461     int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
462
463     Value *ptr = LI->getPointerOperand();
464     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
465     Value *args[] = {ptr, order};
466
467     int size=getTypeSize(ptr->getType());
468     int index=sizetoindex(size);
469
470     Instruction* funcInst=CallInst::Create(CDSAtomicLoad[index], args, "");
471     ReplaceInstWithInst(LI, funcInst);
472     errs() << "Load Replaced\n";
473   } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
474     int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
475
476     Value *val = RMWI->getValOperand();
477     Value *ptr = RMWI->getPointerOperand();
478     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
479     Value *args[] = {ptr, order, val};
480
481     int size = getTypeSize(ptr->getType());
482     int index = sizetoindex(size);
483
484     Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][index], args, "");
485     ReplaceInstWithInst(RMWI, funcInst);
486     errs() << RMWI->getOperationName(RMWI->getOperation());
487     errs() << " replaced\n";
488   } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
489     IRBuilder<> IRB(CASI);
490
491     Value *Addr = CASI->getPointerOperand();
492
493     int size = getTypeSize(Addr->getType());
494     int index = sizetoindex(size);
495     const unsigned ByteSize = 1U << index;
496     const unsigned BitSize = ByteSize * 8;
497     Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
498     Type *PtrTy = Ty->getPointerTo();
499
500     Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
501     Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
502
503     int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
504     int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
505     Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
506     Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
507
508     Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
509                      CmpOperand, NewOperand,
510                      order_succ, order_fail};
511
512     CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS[index], Args);
513     Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
514
515     Value *OldVal = funcInst;
516     Type *OrigOldValTy = CASI->getNewValOperand()->getType();
517     if (Ty != OrigOldValTy) {
518       // The value is a pointer, so we need to cast the return value.
519       OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
520     }
521
522     Value *Res =
523       IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
524     Res = IRB.CreateInsertValue(Res, Success, 1);
525
526     I->replaceAllUsesWith(Res);
527     I->eraseFromParent();
528   } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
529     int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
530     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
531     Value *Args[] = {order};
532
533     CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
534     ReplaceInstWithInst(FI, funcInst);
535     errs() << "Thread Fences replaced\n";
536   }
537   return true;
538 }
539
540
541
542 char CDSPass::ID = 0;
543
544 // Automatically enable the pass.
545 // http://adriansampson.net/blog/clangpass.html
546 static void registerCDSPass(const PassManagerBuilder &,
547                          legacy::PassManagerBase &PM) {
548   PM.add(new CDSPass());
549 }
550 static RegisterStandardPasses 
551         RegisterMyPass(PassManagerBuilder::EP_EarlyAsPossible,
552 registerCDSPass);