From b1b97833aeaf8a7ef6dd3b314a502a1521b02657 Mon Sep 17 00:00:00 2001 From: Andrew Trick Date: Wed, 29 Aug 2012 21:46:38 +0000 Subject: [PATCH] Preserve branch profile metadata during switch formation. Patch by Michael Ilseman! This fixes SimplifyCFGOpt::FoldValueComparisonIntoPredecessors to preserve metata when folding conditional branches into switches. void foo(int x) { if (x == 0) bar(1); else if (__builtin_expect(x == 10, 1)) bar(2); else if (x == 20) bar(3); } CFG: B0 | \ | X0 B10 | \ | X10 B20 | \ E X20 Merge B0-B10: w(B0-X0) = w(B0-X0)*sum-weights(B10) = w(B0-X0) * (w(B10-X10) + w(B10-B20)) w(B0-X10) = w(B0-B10) * w(B10-X10) w(B0-B20) = w(B0-B10) * w(B10-B20) B0 __ | \ \ | X10 X0 B20 | \ E X20 Merge B0-B20: w(B0-X0) = w(B0-X0) * sum-weights(B20) = w(B0-X0) * (w(B20-E) + w(B20-X20)) w(B0-X10) = w(B0-X10) * sum-weights(B20) = ... w(B0-X20) = w(B0-B20) * w(B20-X20) w(B0-E) = w(B0-B20) * w(B20-E) git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@162868 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Utils/SimplifyCFG.cpp | 154 +++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 06a61a802b5..dddc18fcefa 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -615,6 +615,9 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, assert(ThisVal && "This isn't a value comparison!!"); if (ThisVal != PredVal) return false; // Different predicates. + // TODO: Preserve branch weight metadata, similarly to how + // FoldValueComparisonIntoPredecessors preserves it. + // Find out information about when control will move from Pred to TI's block. std::vector PredCases; BasicBlock *PredDef = GetValueEqualityComparisonCases(Pred->getTerminator(), @@ -738,6 +741,67 @@ static int ConstantIntSortPredicate(const void *P1, const void *P2) { return -1; } +static inline bool HasBranchWeights(const Instruction* I) { + MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof); + if (ProfMD && ProfMD->getOperand(0)) + if (MDString* MDS = dyn_cast(ProfMD->getOperand(0))) + return MDS->getString().equals("branch_weights"); + + return false; +} + +/// Tries to get a branch weight for the given instruction, returns NULL if it +/// can't. Pos starts at 0. +static ConstantInt* GetWeight(Instruction* I, int Pos) { + MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof); + if (ProfMD && ProfMD->getOperand(0)) { + if (MDString* MDS = dyn_cast(ProfMD->getOperand(0))) { + if (MDS->getString().equals("branch_weights")) { + assert(ProfMD->getNumOperands() >= 3); + return dyn_cast(ProfMD->getOperand(1 + Pos)); + } + } + } + + return 0; +} + +/// Scale the given weights based on the new TI's metadata. Scaling is done by +/// multiplying every weight by the sum of the successor's weights. +static void ScaleWeights(Instruction* STI, MutableArrayRef Weights) { + // Sum the successor's weights + assert(HasBranchWeights(STI)); + unsigned Scale = 0; + MDNode* ProfMD = STI->getMetadata(LLVMContext::MD_prof); + for (unsigned i = 1; i < ProfMD->getNumOperands(); ++i) { + ConstantInt* CI = dyn_cast(ProfMD->getOperand(i)); + assert(CI); + Scale += CI->getValue().getZExtValue(); + } + + // Skip default, as it's replaced during the folding + for (unsigned i = 1; i < Weights.size(); ++i) { + Weights[i] *= Scale; + } +} + +/// Sees if any of the weights are too big for a uint32_t, and halves all the +/// weights if any are. +static void FitWeights(MutableArrayRef Weights) { + bool Halve = false; + for (unsigned i = 0; i < Weights.size(); ++i) + if (Weights[i] > UINT_MAX) { + Halve = true; + break; + } + + if (! Halve) + return; + + for (unsigned i = 0; i < Weights.size(); ++i) + Weights[i] /= 2; +} + /// FoldValueComparisonIntoPredecessors - The specified terminator is a value /// equality comparison instruction (either a switch or a branch on "X == c"). /// See if any of the predecessors of the terminator block are value comparisons @@ -770,6 +834,55 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // build. SmallVector NewSuccessors; + // Update the branch weight metadata along the way + SmallVector Weights; + uint64_t PredDefaultWeight = 0; + bool PredHasWeights = HasBranchWeights(PTI); + bool SuccHasWeights = HasBranchWeights(TI); + + if (PredHasWeights) { + MDNode* MD = PTI->getMetadata(LLVMContext::MD_prof); + assert(MD); + for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { + ConstantInt* CI = dyn_cast(MD->getOperand(i)); + assert(CI); + Weights.push_back(CI->getValue().getZExtValue()); + } + + // If the predecessor is a conditional eq, then swap the default weight + // to be the first entry. + if (BranchInst* BI = dyn_cast(PTI)) { + assert(Weights.size() == 2); + ICmpInst *ICI = cast(BI->getCondition()); + + if (ICI->getPredicate() == ICmpInst::ICMP_EQ) { + std::swap(Weights.front(), Weights.back()); + } + } + + PredDefaultWeight = Weights.front(); + } else if (SuccHasWeights) { + // If there are no predecessor weights but there are successor weights, + // populate Weights with 1, which will later be scaled to the sum of + // successor's weights + Weights.assign(1 + PredCases.size(), 1); + PredDefaultWeight = 1; + } + + uint64_t SuccDefaultWeight = 0; + if (SuccHasWeights) { + int Index = 0; + if (BranchInst* BI = dyn_cast(TI)) { + ICmpInst* ICI = dyn_cast(BI->getCondition()); + assert(ICI); + + if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + Index = 1; + } + + SuccDefaultWeight = GetWeight(TI, Index)->getValue().getZExtValue(); + } + if (PredDefault == BB) { // If this is the default destination from PTI, only the edges in TI // that don't occur in PTI, or that branch to BB will be activated. @@ -780,6 +893,12 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, else { // The default destination is BB, we don't need explicit targets. std::swap(PredCases[i], PredCases.back()); + + if (PredHasWeights) { + std::swap(Weights[i+1], Weights.back()); + Weights.pop_back(); + } + PredCases.pop_back(); --i; --e; } @@ -790,14 +909,35 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, PredDefault = BBDefault; NewSuccessors.push_back(BBDefault); } + + if (SuccHasWeights) { + ScaleWeights(TI, Weights); + Weights.front() *= SuccDefaultWeight; + } else if (PredHasWeights) { + Weights.front() /= (1 + BBCases.size()); + } + for (unsigned i = 0, e = BBCases.size(); i != e; ++i) if (!PTIHandled.count(BBCases[i].Value) && BBCases[i].Dest != BBDefault) { PredCases.push_back(BBCases[i]); NewSuccessors.push_back(BBCases[i].Dest); + if (SuccHasWeights) { + Weights.push_back(PredDefaultWeight * + GetWeight(TI, i)->getValue().getZExtValue()); + } else if (PredHasWeights) { + // Split the old default's weight amongst the children + assert(PredDefaultWeight != 0); + Weights.push_back(PredDefaultWeight / (1 + BBCases.size())); + } } } else { + // FIXME: preserve branch weight metadata, similarly to the 'then' + // above. For now, drop it. + PredHasWeights = false; + SuccHasWeights = false; + // If this is not the default destination from PSI, only the edges // in SI that occur in PSI with a destination of BB will be // activated. @@ -851,6 +991,17 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, for (unsigned i = 0, e = PredCases.size(); i != e; ++i) NewSI->addCase(PredCases[i].Value, PredCases[i].Dest); + if (PredHasWeights || SuccHasWeights) { + // Halve the weights if any of them cannot fit in an uint32_t + FitWeights(Weights); + + SmallVector MDWeights(Weights.begin(), Weights.end()); + + NewSI->setMetadata(LLVMContext::MD_prof, + MDBuilder(BB->getContext()). + createBranchWeights(MDWeights)); + } + EraseTerminatorInstAndDCECond(PTI); // Okay, last check. If BB is still a successor of PSI, then we must @@ -2349,6 +2500,9 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const TargetData *TD, // transformation. A switch with one value is just an cond branch. if (ExtraCase && Values.size() < 2) return false; + // TODO: Preserve branch weight metadata, similarly to how + // FoldValueComparisonIntoPredecessors preserves it. + // Figure out which block is which destination. BasicBlock *DefaultBB = BI->getSuccessor(1); BasicBlock *EdgeBB = BI->getSuccessor(0); -- 2.34.1