1 //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
3 // This file is distributed under the University of Illinois Open Source
4 // License. See LICENSE.TXT for details.
6 //===----------------------------------------------------------------------===//
9 /// \brief A pass that instruments code with fast checks for indirect calls and
10 /// hooks for a function to check violations.
12 //===----------------------------------------------------------------------===//
14 #define DEBUG_TYPE "cfi"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/JumpInstrTableInfo.h"
19 #include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
20 #include "llvm/CodeGen/JumpInstrTables.h"
21 #include "llvm/CodeGen/Passes.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/CallSite.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/InlineAsm.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Operator.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/IR/Verifier.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
43 STATISTIC(NumCFIIndirectCalls,
44 "Number of indirect call sites rewritten by the CFI pass");
46 char ForwardControlFlowIntegrity::ID = 0;
47 INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
48 "Control-Flow Integrity", true, true)
49 INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
50 INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
51 INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
52 "Control-Flow Integrity", true, true)
54 ModulePass *llvm::createForwardControlFlowIntegrityPass() {
55 return new ForwardControlFlowIntegrity();
58 ModulePass *llvm::createForwardControlFlowIntegrityPass(
59 JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
60 StringRef CFIFuncName) {
61 return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
65 // Checks to see if a given CallSite is making an indirect call, including
66 // cases where the indirect call is made through a bitcast.
67 static bool isIndirectCall(CallSite &CS) {
68 if (CS.getCalledFunction())
71 // Check the value to see if it is merely a bitcast of a function. In
72 // this case, it will translate to a direct function call in the resulting
73 // assembly, so we won't treat it as an indirect call here.
74 const Value *V = CS.getCalledValue();
75 if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
76 return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
79 // Otherwise, since we know it's a call, it must be an indirect call
/// Name of the default warning function. It is used when the pass is
/// configured with an empty CFIFuncName.
static const char cfi_failure_func_name[] =
    "__llvm_cfi_pointer_warning";
85 ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
86 : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
87 CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
88 initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
91 ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
92 JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
93 std::string CFIFuncName)
94 : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
95 CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
96 initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
99 ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
101 void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
102 AU.addRequired<JumpInstrTableInfo>();
103 AU.addRequired<JumpInstrTables>();
106 void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
107 // To get the indirect calls, we iterate over all functions and iterate over
108 // the list of basic blocks in each. We extract a total list of indirect calls
109 // before modifying any of them, since our modifications will modify the list
111 for (Function &F : M) {
112 for (BasicBlock &BB : F) {
113 for (Instruction &I : BB) {
115 if (!(CS && isIndirectCall(CS)))
118 Value *CalledValue = CS.getCalledValue();
120 // Don't rewrite this instruction if the indirect call is actually just
121 // inline assembly, since our transformation will generate an invalid
122 // module in that case.
123 if (isa<InlineAsm>(CalledValue))
126 IndirectCalls.push_back(&I);
// Rewrites each call site gathered in IndirectCalls.
//
// For every call, it looks up CFIT by the call's function type after
// JumpInstrTables::transformType(JTType, ...) and passes the resulting start,
// mask and size constants to rewriteFunctionPointer(). A type with no CFIT
// entry gets all-zero constants, so the rewritten pointer differs from the
// original (or, for the Ror check, the size comparison fails). Every
// processed call site increments NumCFIIndirectCalls.
132 void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
134 Type *Int64Ty = Type::getInt64Ty(M.getContext());
135 for (Instruction *I : IndirectCalls) {
137 Value *CalledValue = CS.getCalledValue();
139 // Get the function type for this call and look it up in the tables.
140 Type *VTy = CalledValue->getType();
// NOTE(review): both dyn_cast results below are used without a null check.
// If the callee here is always a pointer to a FunctionType, cast<> would state
// that invariant and assert on a violation instead of dereferencing null.
141 PointerType *PTy = dyn_cast<PointerType>(VTy);
142 Type *EltTy = PTy->getElementType();
143 FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
144 FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
145 ++NumCFIIndirectCalls;
146 Constant *JumpTableStart = nullptr;
147 Constant *JumpTableMask = nullptr;
148 Constant *JumpTableSize = nullptr;
150 // Some call sites have function types that don't correspond to any
151 // address-taken function in the module. This happens when function pointers
152 // are passed in from external code.
153 auto it = CFIT.find(TransformedTy);
154 if (it == CFIT.end()) {
155 // In this case, make sure that the function pointer will change by
156 // setting the mask and the start to be 0 so that the transformed
// NOTE(review): JumpTableStart is an i64 constant here. runOnModule stores an
// i8* bitcast instead. rewriteFunctionPointer applies CreatePtrToInt to it,
// which presumably relies on IRBuilder returning a value that already has the
// destination type unchanged -- confirm.
158 JumpTableStart = ConstantInt::get(Int64Ty, 0);
159 JumpTableMask = ConstantInt::get(Int64Ty, 0);
160 JumpTableSize = ConstantInt::get(Int64Ty, 0);
162 JumpTableStart = it->second.StartValue;
163 JumpTableMask = it->second.MaskValue;
164 JumpTableSize = it->second.Size;
167 rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
// Pass entry point. It works in three steps:
//  1. Read the entry alignment and the jump tables from JumpInstrTableInfo.
//  2. Record per-type start, mask and size constants in CFIT.
//  3. Add the warning function and instrument the module's indirect calls.
174 bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
175 JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
176 Type *Int64Ty = Type::getInt64Ty(M.getContext());
177 Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
179 // JumpInstrTableInfo stores information about the alignment of each entry.
180 // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
// NOTE(review): Log2_64 rounds down. The shift and mask arithmetic below
// therefore assumes entryByteAlignment() is a power of two -- confirm.
182 ByteAlignment = JITI->entryByteAlignment();
183 LogByteAlignment = llvm::Log2_64(ByteAlignment);
185 // Set up tables for control-flow integrity based on information about the
186 // jump-instruction tables.
188 for (const auto &KV : JITI->getTables()) {
189 uint64_t Size = static_cast<uint64_t>(KV.second.size());
// NOTE(review): NextPowerOf2 returns the next power of two strictly greater
// than Size. When Size is already a power of two, the mask therefore spans
// twice as many entries as the table holds. Confirm this matches how the
// emitted jump tables are padded.
190 uint64_t TableSize = NextPowerOf2(Size);
// TableSize entries of ByteAlignment bytes each, minus one, with the low
// log2(ByteAlignment) bits cleared. Masking with it therefore yields an
// entry-aligned offset below TableSize * ByteAlignment.
192 int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
193 Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
194 Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
196 // The base of the table is defined to be the first jumptable function in
// NOTE(review): this assumes KV.second iterates in table order, so that
// begin() is the entry at the table's lowest address -- confirm against
// JumpInstrTableInfo.
198 Function *First = KV.second.begin()->second;
199 Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
200 CFIT[KV.first].StartValue = JumpTableStartValue;
201 CFIT[KV.first].MaskValue = JumpTableMaskValue;
202 CFIT[KV.first].Size = JumpTableSize;
211 addWarningFunction(M);
214 // Update the instructions with the check and the indirect jump through our
216 updateIndirectCalls(M, CFIT);
// Ensures M contains the violation handler, which has type void(i8*, i8*).
//  - If CFIFuncName is non-empty, that function is looked up, or declared with
//    this type.
//  - Otherwise the function named cfi_failure_func_name is used (and created
//    with linkonce_any linkage if absent). It is given a body that only
//    returns.
221 void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
222 PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
224 // Get the type of the Warning Function: void (i8*, i8*),
225 // where the first argument is the name of the function in which the violation
226 // occurs, and the second is the function pointer that violates CFI.
227 SmallVector<Type *, 2> WarningFunArgs;
228 WarningFunArgs.push_back(CharPtrTy);
229 WarningFunArgs.push_back(CharPtrTy);
230 FunctionType *WarningFunTy =
231 FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
233 if (!CFIFuncName.empty()) {
// NOTE(review): if M already contains CFIFuncName with a different type,
// getOrInsertFunction returns a bitcast of it rather than a new function.
// insertWarning later fetches the function with M.getFunction(CFIFuncName)
// and calls it with two i8* arguments, so such a pre-existing declaration
// would produce a mismatched call -- confirm this cannot happen.
234 Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
236 report_fatal_error("Could not get or insert the function specified by"
239 // The default warning function swallows the warning and lets the call
240 // continue, since there's no generic way for it to print out this
242 Function *WarningFun = M.getFunction(cfi_failure_func_name);
245 Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
246 cfi_failure_func_name, &M);
// NOTE(review): as far as is visible here, the "entry" block and its ret are
// added even when M.getFunction() found an existing function. If that
// function already had a body, this appends a second, unreachable block.
// Consider adding the body only when WarningFun->empty().
250 BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
251 ReturnInst::Create(M.getContext(), Entry);
// Instruments the call instruction I, whose callee is FunPtr. Immediately
// before I it emits integer arithmetic that maps FunPtr into the jump table
// described by JumpTableStart, JumpTableMask and JumpTableSize, using the
// sequence selected by CFIType:
//  - Sub: subtract the base, mask, add the base back;
//  - Ror: subtract the base, rotate right by log2(alignment), compare against
//    the entry count;
//  - Add: mask, then add the base.
// The code then does one or both of the following (presumably selected by
// CFIEnforcing -- confirm):
//  - splits I's block and, when the check fails, branches through a new
//    "invalid.ptr" block that calls the warning function;
//  - redirects the call to the rewritten pointer NewFunPtr.
255 void ForwardControlFlowIntegrity::rewriteFunctionPointer(
256 Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
257 Constant *JumpTableMask, Constant *JumpTableSize) {
258 IRBuilder<> TempBuilder(I);
260 Type *OrigFunType = FunPtr->getType();
// getParent() already returns BasicBlock* / Function*, so these cast<>s are
// no-ops kept for readability.
262 BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
263 Function *CurF = cast<Function>(CurBB->getParent());
264 Type *Int64Ty = Type::getInt64Ty(M.getContext());
266 Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
267 Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
269 Value *NewFunPtr = nullptr;
270 Value *Check = nullptr;
272 case CFIntegrity::Sub: {
273 // This is the subtract, mask, and add version.
274 // Subtract from the base.
275 Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
277 // Mask the difference to force this to be a table offset.
278 Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
280 // Add it back to the base.
281 Value *Result = TempBuilder.CreateAdd(And, TStartInt);
283 // Convert it back into a function pointer that we can call.
284 NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
287 case CFIntegrity::Ror: {
288 // This is the subtract and rotate version.
289 // Rotate right by the alignment value. The optimizer should recognize
290 // this sequence as a rotation.
292 // This cast is safe, since unsigned is always a subset of uint64_t.
293 uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
294 Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
// NOTE(review): when LogByteAlignment is 0 (1-byte entries), this shift amount
// is 64. A shl by the full bit width yields poison in LLVM IR, so that case
// needs to be excluded or special-cased.
295 Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
297 // Subtract from the base.
298 Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
300 // Create the equivalent of a rotate-right instruction.
301 Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
302 Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
303 Value *Or = TempBuilder.CreateOr(Shr, Shl);
305 // Perform unsigned comparison to check for inclusion in the table.
// NOTE(review): no NewFunPtr assignment is visible for this case. If the call
// is later redirected with CS.setCalledFunction(NewFunPtr) in Ror mode,
// NewFunPtr must be non-null here -- verify.
306 Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
310 case CFIntegrity::Add: {
311 // This is the mask and add version.
312 // Mask the function pointer to turn it into an offset into the table.
313 Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
315 // Then add this offset to the base and get the pointer value.
316 Value *Result = TempBuilder.CreateAdd(And, TStartInt);
318 // Convert it back into a function pointer that we can call.
319 NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
325 // If a check hasn't been added (in the rotation version), then check to see
326 // if it's the same as the original function. This check determines whether
327 // or not we call the CFI failure function.
329 Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
330 BasicBlock *InvalidPtrBlock =
331 BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
// Everything from I onward moves to ContinuationBB. The check instructions
// emitted above stay in CurBB.
332 BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
334 // Remove the unconditional branch that connects the two blocks.
335 TerminatorInst *TermInst = CurBB->getTerminator();
336 TermInst->eraseFromParent();
338 // Add a conditional branch that depends on the Check above.
339 BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
341 // Call the warning function for this pointer, then continue.
342 Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
343 insertWarning(M, InvalidPtrBlock, BI, FunPtr);
345 // Modify the instruction to call this value.
347 CS.setCalledFunction(NewFunPtr);
351 void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
352 Instruction *I, Value *FunPtr) {
353 Function *ParentFun = cast<Function>(Block->getParent());
355 // Get the function to call right before the instruction.
356 Function *WarningFun = nullptr;
357 if (CFIFuncName.empty()) {
358 WarningFun = M.getFunction(cfi_failure_func_name);
360 WarningFun = M.getFunction(CFIFuncName);
363 assert(WarningFun && "Could not find the CFI failure function");
365 Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
367 IRBuilder<> WarningInserter(I);
368 // Create a mergeable GlobalVariable containing the name of the function.
369 Value *ParentNameGV =
370 WarningInserter.CreateGlobalString(ParentFun->getName());
371 Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
372 Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
373 WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);