From: Chris Lattner Date: Tue, 27 Sep 2005 21:18:17 +0000 (+0000) Subject: Completely rewrite 'correct' eh support. This changes how setjmp insertion X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=f4e6c3a69b3cf3a381811900c0f3768f626e90b8;p=oota-llvm.git Completely rewrite 'correct' eh support. This changes how setjmp insertion is performed so it is only at most once per function that contains an invoke instead of once per invoke in the function. This patch has the following perks: 1. It fixes PR631, which complains about slowness. 2. If fixes PR240, which complains about non-volatile vars being live across setjmp/longjmps. 3. It improves (but does not fix) the jmpbuf alignment issue on itanium by not forcing the jmpbufs to always be 8-bytes off the alignment of the structure. 4. It speeds up 253.perlbmk from 338s to 13.70s (a 25x improvement!), making us now about 4% faster than GCC. Further improvements are also possible. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@23477 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/Utils/LowerInvoke.cpp b/lib/Transforms/Utils/LowerInvoke.cpp index b0d8fb8a308..54724b5f1ac 100644 --- a/lib/Transforms/Utils/LowerInvoke.cpp +++ b/lib/Transforms/Utils/LowerInvoke.cpp @@ -41,13 +41,17 @@ #include "llvm/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/ADT/Statistic.h" #include "llvm/Support/CommandLine.h" #include using namespace llvm; namespace { - Statistic<> NumLowered("lowerinvoke", "Number of invoke & unwinds replaced"); + Statistic<> NumInvokes("lowerinvoke", "Number of invokes replaced"); + Statistic<> NumUnwinds("lowerinvoke", "Number of unwinds replaced"); + Statistic<> NumSpilled("lowerinvoke", + "Number of registers live across unwind edges"); cl::opt ExpensiveEHSupport("enable-correct-eh-support", cl::desc("Make the -lowerinvoke pass insert expensive, but correct, EH code")); @@ -65,10 +69,14 @@ namespace { public: bool doInitialization(Module &M); bool runOnFunction(Function &F); + private: void createAbortMessage(); void writeAbortMessage(Instruction *IB); bool insertCheapEHSupport(Function &F); + void splitLiveRangesLiveAcrossInvokes(std::vector &Invokes); + void rewriteExpensiveInvoke(InvokeInst *II, unsigned InvokeNo, + AllocaInst *InvokeNum, SwitchInst *CatchSwitch); bool insertExpensiveEHSupport(Function &F); }; @@ -97,9 +105,9 @@ bool LowerInvoke::doInitialization(Module &M) { { // The type is recursive, so use a type holder. std::vector Elements; + Elements.push_back(JmpBufTy); OpaqueType *OT = OpaqueType::get(); Elements.push_back(PointerType::get(OT)); - Elements.push_back(JmpBufTy); PATypeHolder JBLType(StructType::get(Elements)); OT->refineAbstractTypeTo(JBLType.get()); // Complete the cycle. JBLinkTy = JBLType.get(); @@ -220,7 +228,7 @@ bool LowerInvoke::insertCheapEHSupport(Function &F) { // Remove the invoke instruction now. BB->getInstList().erase(II); - ++NumLowered; Changed = true; + ++NumInvokes; Changed = true; } else if (UnwindInst *UI = dyn_cast(BB->getTerminator())) { // Insert a new call to write(2, AbortMessage, AbortMessageLength); writeAbortMessage(UI); @@ -236,163 +244,316 @@ bool LowerInvoke::insertCheapEHSupport(Function &F) { // Remove the unwind instruction now. BB->getInstList().erase(UI); - ++NumLowered; Changed = true; + ++NumUnwinds; Changed = true; } return Changed; } -bool LowerInvoke::insertExpensiveEHSupport(Function &F) { - bool Changed = false; +/// rewriteExpensiveInvoke - Insert code and hack the function to replace the +/// specified invoke instruction with a call. +void LowerInvoke::rewriteExpensiveInvoke(InvokeInst *II, unsigned InvokeNo, + AllocaInst *InvokeNum, + SwitchInst *CatchSwitch) { + ConstantUInt *InvokeNoC = ConstantUInt::get(Type::UIntTy, InvokeNo); + + // Insert a store of the invoke num before the invoke and store zero into the + // location afterward. + new StoreInst(InvokeNoC, InvokeNum, true, II); // volatile + new StoreInst(Constant::getNullValue(Type::UIntTy), InvokeNum, false, + II->getNormalDest()->begin()); // nonvolatile. + + // Add a switch case to our unwind block. + CatchSwitch->addCase(InvokeNoC, II->getUnwindDest()); + + // Insert a normal call instruction. + std::string Name = II->getName(); II->setName(""); + CallInst *NewCall = new CallInst(II->getCalledValue(), + std::vector(II->op_begin()+3, + II->op_end()), Name, + II); + NewCall->setCallingConv(II->getCallingConv()); + II->replaceAllUsesWith(NewCall); + + // Replace the invoke with an uncond branch. + new BranchInst(II->getNormalDest(), NewCall->getParent()); + II->eraseFromParent(); +} - // If a function uses invoke, we have an alloca for the jump buffer. - AllocaInst *JmpBuf = 0; +/// MarkBlocksLiveIn - Insert BB and all of its predescessors into LiveBBs until +/// we reach blocks we've already seen. +static void MarkBlocksLiveIn(BasicBlock *BB, std::set &LiveBBs) { + if (!LiveBBs.insert(BB).second) return; // already been here. + + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) + MarkBlocksLiveIn(*PI, LiveBBs); +} - // If this function contains an unwind instruction, two blocks get added: one - // to actually perform the longjmp, and one to terminate the program if there - // is no handler. - BasicBlock *UnwindBlock = 0, *TermBlock = 0; - std::vector JBPtrs; +// First thing we need to do is scan the whole function for values that are +// live across unwind edges. Each value that is live across an unwind edge +// we spill into a stack location, guaranteeing that there is nothing live +// across the unwind edge. This process also splits all critical edges +// coming out of invoke's. +void LowerInvoke:: +splitLiveRangesLiveAcrossInvokes(std::vector &Invokes) { + // First step, split all critical edges from invoke instructions. + for (unsigned i = 0, e = Invokes.size(); i != e; ++i) { + InvokeInst *II = Invokes[i]; + SplitCriticalEdge(II, 0, this); + SplitCriticalEdge(II, 1, this); + assert(!isa(II->getNormalDest()) && + !isa(II->getUnwindDest()) && + "critical edge splitting left single entry phi nodes?"); + } - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) - if (InvokeInst *II = dyn_cast(BB->getTerminator())) { - if (JmpBuf == 0) - JmpBuf = new AllocaInst(JBLinkTy, 0, "jblink", F.begin()->begin()); - - // On the entry to the invoke, we must install our JmpBuf as the top of - // the stack. - LoadInst *OldEntry = new LoadInst(JBListHead, "oldehlist", II); - - // Store this old value as our 'next' field, and store our alloca as the - // current jblist. - std::vector Idx; - Idx.push_back(Constant::getNullValue(Type::IntTy)); - Idx.push_back(ConstantUInt::get(Type::UIntTy, 0)); - Value *NextFieldPtr = new GetElementPtrInst(JmpBuf, Idx, "NextField", II); - new StoreInst(OldEntry, NextFieldPtr, II); - new StoreInst(JmpBuf, JBListHead, II); - - // Call setjmp, passing in the address of the jmpbuffer. - Idx[1] = ConstantUInt::get(Type::UIntTy, 1); - Value *JmpBufPtr = new GetElementPtrInst(JmpBuf, Idx, "TheJmpBuf", II); - Value *SJRet = new CallInst(SetJmpFn, JmpBufPtr, "sjret", II); - - // Compare the return value to zero. - Value *IsNormal = BinaryOperator::create(Instruction::SetEQ, SJRet, - Constant::getNullValue(SJRet->getType()), - "notunwind", II); - // Create the receiver block if there is a critical edge to the normal - // destination. - SplitCriticalEdge(II, 0, this); + Function *F = Invokes.back()->getParent()->getParent(); + + // To avoid having to handle incoming arguments specially, we lower each arg + // to a copy instruction in the entry block. This ensure that the argument + // value itself cannot be live across the entry block. + BasicBlock::iterator AfterAllocaInsertPt = F->begin()->begin(); + while (isa(AfterAllocaInsertPt) && + isa(cast(AfterAllocaInsertPt)->getArraySize())) + ++AfterAllocaInsertPt; + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI) { + CastInst *NC = new CastInst(AI, AI->getType(), AI->getName()+".tmp", + AfterAllocaInsertPt); + AI->replaceAllUsesWith(NC); + NC->setOperand(0, AI); + } + + // Finally, scan the code looking for instructions with bad live ranges. + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ++II) { + // Ignore obvious cases we don't have to handle. In particular, most + // instructions either have no uses or only have a single use inside the + // current block. Ignore them quickly. + Instruction *Inst = II; + if (Inst->use_empty()) continue; + if (Inst->hasOneUse() && + cast(Inst->use_back())->getParent() == BB && + !isa(Inst->use_back())) continue; - // There should not be any PHI nodes in II->getNormalDest() now. It has - // a single predecessor, so any PHI nodes are unneeded. Remove them now - // by replacing them with their single input value. - assert(II->getNormalDest()->getSinglePredecessor() && - "Split crit edge doesn't have a single predecessor!"); - - BasicBlock::iterator InsertLoc = II->getNormalDest()->begin(); - while (PHINode *PN = dyn_cast(InsertLoc)) { - PN->replaceAllUsesWith(PN->getIncomingValue(0)); - PN->eraseFromParent(); - InsertLoc = II->getNormalDest()->begin(); + // Avoid iterator invalidation by copying users to a temporary vector. + std::vector Users; + for (Value::use_iterator UI = Inst->use_begin(), E = Inst->use_end(); + UI != E; ++UI) { + Instruction *User = cast(*UI); + if (User->getParent() != BB || isa(User)) + Users.push_back(User); } - - // Insert a normal call instruction on the normal execution path. - std::string Name = II->getName(); II->setName(""); - CallInst *NewCall = new CallInst(II->getCalledValue(), - std::vector(II->op_begin()+3, - II->op_end()), Name, - InsertLoc); - NewCall->setCallingConv(II->getCallingConv()); - II->replaceAllUsesWith(NewCall); - - // If we got this far, then no exception was thrown and we can pop our - // jmpbuf entry off. - new StoreInst(OldEntry, JBListHead, InsertLoc); - - // Now we change the invoke into a branch instruction. - new BranchInst(II->getNormalDest(), II->getUnwindDest(), IsNormal, II); - - // Remove the InvokeInst now. - BB->getInstList().erase(II); - ++NumLowered; Changed = true; - } else if (UnwindInst *UI = dyn_cast(BB->getTerminator())) { - if (UnwindBlock == 0) { - // Create two new blocks, the unwind block and the terminate block. Add - // them at the end of the function because they are not hot. - UnwindBlock = new BasicBlock("unwind", &F); - TermBlock = new BasicBlock("unwinderror", &F); - - // Insert return instructions. These really should be "barrier"s, as - // they are unreachable. - new ReturnInst(F.getReturnType() == Type::VoidTy ? 0 : - Constant::getNullValue(F.getReturnType()), UnwindBlock); - new ReturnInst(F.getReturnType() == Type::VoidTy ? 0 : - Constant::getNullValue(F.getReturnType()), TermBlock); + // Scan all of the uses and see if the live range is live across an unwind + // edge. If we find a use live across an invoke edge, create an alloca + // and spill the value. + AllocaInst *SpillLoc = 0; + std::set InvokesWithStoreInserted; + + // Find all of the blocks that this value is live in. + std::set LiveBBs; + LiveBBs.insert(Inst->getParent()); + while (!Users.empty()) { + Instruction *U = Users.back(); + Users.pop_back(); + + BasicBlock *UseBlock; + if (!isa(U)) { + MarkBlocksLiveIn(U->getParent(), LiveBBs); + } else { + // Uses for a PHI node occur in their predecessor block. + PHINode *PN = cast(U); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == Inst) + MarkBlocksLiveIn(PN->getIncomingBlock(i), LiveBBs); + } + } + + // Now that we know all of the blocks that this thing is live in, see if + // it includes any of the unwind locations. + bool NeedsSpill = false; + for (unsigned i = 0, e = Invokes.size(); i != e; ++i) { + BasicBlock *UnwindBlock = Invokes[i]->getUnwindDest(); + if (UnwindBlock != BB && LiveBBs.count(UnwindBlock)) { + NeedsSpill = true; + } } - // Load the JBList, if it's null, then there was no catch! - LoadInst *Ptr = new LoadInst(JBListHead, "ehlist", UI); - Value *NotNull = BinaryOperator::create(Instruction::SetNE, Ptr, - Constant::getNullValue(Ptr->getType()), - "notnull", UI); - new BranchInst(UnwindBlock, TermBlock, NotNull, UI); - - // Remember the loaded value so we can insert the PHI node as needed. - JBPtrs.push_back(Ptr); - - // Remove the UnwindInst now. - BB->getInstList().erase(UI); - ++NumLowered; Changed = true; + // If we decided we need a spill, do it. + if (NeedsSpill) { + ++NumSpilled; + DemoteRegToStack(*Inst, true); + } } +} + +bool LowerInvoke::insertExpensiveEHSupport(Function &F) { + std::vector Returns; + std::vector Unwinds; + std::vector Invokes; - // If an unwind instruction was inserted, we need to set up the Unwind and - // term blocks. - if (UnwindBlock) { - // In the unwind block, we know that the pointer coming in on the JBPtrs - // list are non-null. - Instruction *RI = UnwindBlock->getTerminator(); - - Value *RecPtr; - if (JBPtrs.size() == 1) - RecPtr = JBPtrs[0]; - else { - // If there is more than one unwind in this function, make a PHI node to - // merge in all of the loaded values. - PHINode *PN = new PHINode(JBPtrs[0]->getType(), "jbptrs", RI); - for (unsigned i = 0, e = JBPtrs.size(); i != e; ++i) - PN->addIncoming(JBPtrs[i], JBPtrs[i]->getParent()); - RecPtr = PN; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (ReturnInst *RI = dyn_cast(BB->getTerminator())) { + // Remember all return instructions in case we insert an invoke into this + // function. + Returns.push_back(RI); + } else if (InvokeInst *II = dyn_cast(BB->getTerminator())) { + Invokes.push_back(II); + } else if (UnwindInst *UI = dyn_cast(BB->getTerminator())) { + Unwinds.push_back(UI); } - // Now that we have a pointer to the whole record, remove the entry from the - // JBList. + if (Unwinds.empty() && Invokes.empty()) return false; + + NumInvokes += Invokes.size(); + NumUnwinds += Unwinds.size(); + + // If we have an invoke instruction, insert a setjmp that dominates all + // invokes. After the setjmp, use a cond branch that goes to the original + // code path on zero, and to a designated 'catch' block of nonzero. + Value *OldJmpBufPtr = 0; + if (!Invokes.empty()) { + // First thing we need to do is scan the whole function for values that are + // live across unwind edges. Each value that is live across an unwind edge + // we spill into a stack location, guaranteeing that there is nothing live + // across the unwind edge. This process also splits all critical edges + // coming out of invoke's. + splitLiveRangesLiveAcrossInvokes(Invokes); + + BasicBlock *EntryBB = F.begin(); + + // Create an alloca for the incoming jump buffer ptr and the new jump buffer + // that needs to be restored on all exits from the function. This is an + // alloca because the value needs to be live across invokes. + AllocaInst *JmpBuf = + new AllocaInst(JBLinkTy, 0, "jblink", F.begin()->begin()); + std::vector Idx; Idx.push_back(Constant::getNullValue(Type::IntTy)); - Idx.push_back(ConstantUInt::get(Type::UIntTy, 0)); - Value *NextFieldPtr = new GetElementPtrInst(RecPtr, Idx, "NextField", RI); - Value *NextRec = new LoadInst(NextFieldPtr, "NextRecord", RI); - new StoreInst(NextRec, JBListHead, RI); - - // Now that we popped the top of the JBList, get a pointer to the jmpbuf and - // longjmp. - Idx[1] = ConstantUInt::get(Type::UIntTy, 1); - Idx[0] = new GetElementPtrInst(RecPtr, Idx, "JmpBuf", RI); - Idx[1] = ConstantInt::get(Type::IntTy, 1); - new CallInst(LongJmpFn, Idx, "", RI); - - // Now we set up the terminate block. - RI = TermBlock->getTerminator(); - - // Insert a new call to write(2, AbortMessage, AbortMessageLength); - writeAbortMessage(RI); - - // Insert a call to abort() - (new CallInst(AbortFn, std::vector(), "", RI))->setTailCall(); + Idx.push_back(ConstantUInt::get(Type::UIntTy, 1)); + OldJmpBufPtr = new GetElementPtrInst(JmpBuf, Idx, "OldBuf", + EntryBB->getTerminator()); + + // Copy the JBListHead to the alloca. + Value *OldBuf = new LoadInst(JBListHead, "oldjmpbufptr", true, + EntryBB->getTerminator()); + new StoreInst(OldBuf, OldJmpBufPtr, true, EntryBB->getTerminator()); + + // Add the new jumpbuf to the list. + new StoreInst(JmpBuf, JBListHead, true, EntryBB->getTerminator()); + + // Create the catch block. The catch block is basically a big switch + // statement that goes to all of the invoke catch blocks. + BasicBlock *CatchBB = new BasicBlock("setjmp.catch", &F); + + // Create an alloca which keeps track of which invoke is currently + // executing. For normal calls it contains zero. + AllocaInst *InvokeNum = new AllocaInst(Type::UIntTy, 0, "invokenum", + EntryBB->begin()); + new StoreInst(ConstantInt::get(Type::UIntTy, 0), InvokeNum, true, + EntryBB->getTerminator()); + + // Insert a load in the Catch block, and a switch on its value. By default, + // we go to a block that just does an unwind (which is the correct action + // for a standard call). + BasicBlock *UnwindBB = new BasicBlock("unwindbb", &F); + Unwinds.push_back(new UnwindInst(UnwindBB)); + + Value *CatchLoad = new LoadInst(InvokeNum, "invoke.num", true, CatchBB); + SwitchInst *CatchSwitch = + new SwitchInst(CatchLoad, UnwindBB, Invokes.size(), CatchBB); + + // Now that things are set up, insert the setjmp call itself. + + // Split the entry block to insert the conditional branch for the setjmp. + BasicBlock *ContBlock = EntryBB->splitBasicBlock(EntryBB->getTerminator(), + "setjmp.cont"); + + Idx[1] = ConstantUInt::get(Type::UIntTy, 0); + Value *JmpBufPtr = new GetElementPtrInst(JmpBuf, Idx, "TheJmpBuf", + EntryBB->getTerminator()); + Value *SJRet = new CallInst(SetJmpFn, JmpBufPtr, "sjret", + EntryBB->getTerminator()); + + // Compare the return value to zero. + Value *IsNormal = BinaryOperator::createSetEQ(SJRet, + Constant::getNullValue(SJRet->getType()), + "notunwind", EntryBB->getTerminator()); + // Nuke the uncond branch. + EntryBB->getTerminator()->eraseFromParent(); + + // Put in a new condbranch in its place. + new BranchInst(ContBlock, CatchBB, IsNormal, EntryBB); + + // At this point, we are all set up, rewrite each invoke instruction. + for (unsigned i = 0, e = Invokes.size(); i != e; ++i) + rewriteExpensiveInvoke(Invokes[i], i+1, InvokeNum, CatchSwitch); } - return Changed; + // We know that there is at least one unwind. + + // Create three new blocks, the block to load the jmpbuf ptr and compare + // against null, the block to do the longjmp, and the error block for if it + // is null. Add them at the end of the function because they are not hot. + BasicBlock *UnwindHandler = new BasicBlock("dounwind", &F); + BasicBlock *UnwindBlock = new BasicBlock("unwind", &F); + BasicBlock *TermBlock = new BasicBlock("unwinderror", &F); + + // If this function contains an invoke, restore the old jumpbuf ptr. + Value *BufPtr; + if (OldJmpBufPtr) { + // Before the return, insert a copy from the saved value to the new value. + BufPtr = new LoadInst(OldJmpBufPtr, "oldjmpbufptr", UnwindHandler); + new StoreInst(BufPtr, JBListHead, UnwindHandler); + } else { + BufPtr = new LoadInst(JBListHead, "ehlist", UnwindHandler); + } + + // Load the JBList, if it's null, then there was no catch! + Value *NotNull = BinaryOperator::createSetNE(BufPtr, + Constant::getNullValue(BufPtr->getType()), + "notnull", UnwindHandler); + new BranchInst(UnwindBlock, TermBlock, NotNull, UnwindHandler); + + // Create the block to do the longjmp. + // Get a pointer to the jmpbuf and longjmp. + std::vector Idx; + Idx.push_back(Constant::getNullValue(Type::IntTy)); + Idx.push_back(ConstantUInt::get(Type::UIntTy, 0)); + Idx[0] = new GetElementPtrInst(BufPtr, Idx, "JmpBuf", UnwindBlock); + Idx[1] = ConstantInt::get(Type::IntTy, 1); + new CallInst(LongJmpFn, Idx, "", UnwindBlock); + new UnreachableInst(UnwindBlock); + + // Set up the term block ("throw without a catch"). + new UnreachableInst(TermBlock); + + // Insert a new call to write(2, AbortMessage, AbortMessageLength); + writeAbortMessage(TermBlock->getTerminator()); + + // Insert a call to abort() + (new CallInst(AbortFn, std::vector(), "", + TermBlock->getTerminator()))->setTailCall(); + + + // Replace all unwinds with a branch to the unwind handler. + for (unsigned i = 0, e = Unwinds.size(); i != e; ++i) { + new BranchInst(UnwindHandler, Unwinds[i]); + Unwinds[i]->eraseFromParent(); + } + + // Finally, for any returns from this function, if this function contains an + // invoke, restore the old jmpbuf pointer to its input value. + if (OldJmpBufPtr) { + for (unsigned i = 0, e = Returns.size(); i != e; ++i) { + ReturnInst *R = Returns[i]; + + // Before the return, insert a copy from the saved value to the new value. + Value *OldBuf = new LoadInst(OldJmpBufPtr, "oldjmpbufptr", true, R); + new StoreInst(OldBuf, JBListHead, true, R); + } + } + + return true; } bool LowerInvoke::runOnFunction(Function &F) {