Don't build a vector of returns. Just modify the Function in the loop.
[oota-llvm.git] / lib / CodeGen / StackProtector.cpp
1 //===-- StackProtector.cpp - Stack Protector Insertion --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass inserts stack protectors into functions which need them. A variable
11 // with a random value in it is stored onto the stack before the local variables
12 // are allocated. Upon exiting the block, the stored value is checked. If it's
13 // changed, then there was some sort of violation and the program aborts.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #define DEBUG_TYPE "stack-protector"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/Constants.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Function.h"
22 #include "llvm/Instructions.h"
23 #include "llvm/Intrinsics.h"
24 #include "llvm/Module.h"
25 #include "llvm/Pass.h"
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Target/TargetData.h"
29 #include "llvm/Target/TargetLowering.h"
30 using namespace llvm;
31
32 // SSPBufferSize - The lower bound for a buffer to be considered for stack
33 // smashing protection.
34 static cl::opt<unsigned>
35 SSPBufferSize("stack-protector-buffer-size", cl::init(8),
36               cl::desc("The lower bound for a buffer to be considered for "
37                        "stack smashing protection."));
38
39 namespace {
40   class VISIBILITY_HIDDEN StackProtector : public FunctionPass {
41     /// Level - The level of stack protection.
42     SSP::StackProtectorLevel Level;
43
44     /// TLI - Keep a pointer of a TargetLowering to consult for determining
45     /// target type sizes.
46     const TargetLowering *TLI;
47
48     Function *F;
49     Module *M;
50
51     /// InsertStackProtectors - Insert code into the prologue and epilogue of
52     /// the function.
53     ///
54     ///  - The prologue code loads and stores the stack guard onto the stack.
55     ///  - The epilogue checks the value stored in the prologue against the
56     ///    original value. It calls __stack_chk_fail if they differ.
57     bool InsertStackProtectors();
58
59     /// CreateFailBB - Create a basic block to jump to when the stack protector
60     /// check fails.
61     BasicBlock *CreateFailBB();
62
63     /// RequiresStackProtector - Check whether or not this function needs a
64     /// stack protector based upon the stack protector level.
65     bool RequiresStackProtector() const;
66   public:
67     static char ID;             // Pass identification, replacement for typeid.
68     StackProtector() : FunctionPass(&ID), Level(SSP::OFF), TLI(0) {}
69     StackProtector(SSP::StackProtectorLevel lvl, const TargetLowering *tli)
70       : FunctionPass(&ID), Level(lvl), TLI(tli) {}
71
72     virtual bool runOnFunction(Function &Fn);
73   };
74 } // end anonymous namespace
75
76 char StackProtector::ID = 0;
77 static RegisterPass<StackProtector>
78 X("stack-protector", "Insert stack protectors");
79
80 FunctionPass *llvm::createStackProtectorPass(SSP::StackProtectorLevel lvl,
81                                              const TargetLowering *tli) {
82   return new StackProtector(lvl, tli);
83 }
84
85 bool StackProtector::runOnFunction(Function &Fn) {
86   F = &Fn;
87   M = F->getParent();
88
89   if (!RequiresStackProtector()) return false;
90   
91   return InsertStackProtectors();
92 }
93
94 /// InsertStackProtectors - Insert code into the prologue and epilogue of the
95 /// function.
96 ///
97 ///  - The prologue code loads and stores the stack guard onto the stack.
98 ///  - The epilogue checks the value stored in the prologue against the original
99 ///    value. It calls __stack_chk_fail if they differ.
100 bool StackProtector::InsertStackProtectors() {
101   Constant *StackGuardVar = 0;  // The global variable for the stack guard.
102   BasicBlock *FailBB = 0;       // The basic block to jump to if check fails.
103
104   // Loop through the basic blocks that have return instructions. Convert this:
105   //
106   //   return:
107   //     ...
108   //     ret ...
109   //
110   // into this:
111   //
112   //   return:
113   //     ...
114   //     %1 = load __stack_chk_guard
115   //     %2 = load <stored stack guard>
116   //     %3 = cmp i1 %1, %2
117   //     br i1 %3, label %SP_return, label %CallStackCheckFailBlk
118   //
119   //   SP_return:
120   //     ret ...
121   //
122   //   CallStackCheckFailBlk:
123   //     call void @__stack_chk_fail()
124   //     unreachable
125   //
126   for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
127     BasicBlock *BB = I;
128
129     if (isa<ReturnInst>(BB->getTerminator())) {
130       // Create the basic block to jump to when the guard check fails.
131       if (!FailBB)
132         FailBB = CreateFailBB();
133
134       if (!StackGuardVar)
135         StackGuardVar =
136           M->getOrInsertGlobal("__stack_chk_guard",
137                                PointerType::getUnqual(Type::Int8Ty));
138
139       ReturnInst *RI = cast<ReturnInst>(BB->getTerminator());
140       Function::iterator InsPt = BB; ++InsPt; // Insertion point for new BB.
141       ++I;
142
143       // Split the basic block before the return instruction.
144       BasicBlock *NewBB = BB->splitBasicBlock(RI, "SP_return");
145
146       // Move the newly created basic block to the point right after the old basic
147       // block so that it's in the "fall through" position.
148       NewBB->removeFromParent();
149       F->getBasicBlockList().insert(InsPt, NewBB);
150
151       // Generate the stack protector instructions in the old basic block.
152       LoadInst *LI1 = new LoadInst(StackGuardVar, "", false, BB);
153       CallInst *CI = CallInst::
154         Create(Intrinsic::getDeclaration(M, Intrinsic::stackprotector_check),
155                "", BB);
156       ICmpInst *Cmp = new ICmpInst(CmpInst::ICMP_EQ, CI, LI1, "", BB);
157       BranchInst::Create(NewBB, FailBB, Cmp, BB);
158     }
159   }
160
161   // Return if we didn't modify any basic blocks. I.e., there are no return
162   // statements in the function.
163   if (!FailBB) return false;
164
165   // Insert code into the entry block that stores the __stack_chk_guard variable
166   // onto the stack.
167   BasicBlock &Entry = F->getEntryBlock();
168   Instruction *InsertPt = &Entry.front();
169
170   LoadInst *LI = new LoadInst(StackGuardVar, "StackGuard", false, InsertPt);
171   CallInst::
172     Create(Intrinsic::getDeclaration(M, Intrinsic::stackprotector_create),
173            LI, "", InsertPt);
174
175   return true;
176 }
177
178 /// CreateFailBB - Create a basic block to jump to when the stack protector
179 /// check fails.
180 BasicBlock *StackProtector::CreateFailBB() {
181   BasicBlock *FailBB = BasicBlock::Create("CallStackCheckFailBlk", F);
182   Constant *StackChkFail =
183     M->getOrInsertFunction("__stack_chk_fail", Type::VoidTy, NULL);
184   CallInst::Create(StackChkFail, "", FailBB);
185   new UnreachableInst(FailBB);
186   return FailBB;
187 }
188
189 /// RequiresStackProtector - Check whether or not this function needs a stack
190 /// protector based upon the stack protector level. The heuristic we use is to
191 /// add a guard variable to functions that call alloca, and functions with
192 /// buffers larger than 8 bytes.
193 bool StackProtector::RequiresStackProtector() const {
194   switch (Level) {
195   default: return false;
196   case SSP::ALL: return true;
197   case SSP::SOME: {
198     const TargetData *TD = TLI->getTargetData();
199
200     for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
201       BasicBlock *BB = I;
202
203       for (BasicBlock::iterator
204              II = BB->begin(), IE = BB->end(); II != IE; ++II)
205         if (AllocaInst *AI = dyn_cast<AllocaInst>(II)) {
206           if (!AI->isArrayAllocation()) continue; // Only care about arrays.
207
208           if (ConstantInt *CI = dyn_cast<ConstantInt>(AI->getArraySize())) {
209             const Type *Ty = AI->getAllocatedType();
210             uint64_t TySize = TD->getABITypeSize(Ty);
211
212             // If an array has more than 8 bytes of allocated space, then we
213             // emit stack protectors.
214             if (SSPBufferSize <= TySize * CI->getZExtValue())
215               return true;
216           } else {
217             // This is a call to alloca with a variable size. Default to adding
218             // stack protectors.
219             return true;
220           }
221         }
222     }
223
224     return false;
225   }
226   }
227 }