From: Manman Ren Date: Sat, 15 Sep 2012 00:39:57 +0000 (+0000) Subject: PGO: preserve branch-weight metadata when simplifying two branches with a common X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=062986c2f0a020c0652010e3c8549653397048f4;p=oota-llvm.git PGO: preserve branch-weight metadata when simplifying two branches with a common destination. Updated previous implementation to fix a case not covered: // PBI: br i1 %x, TrueDest, BB // BI: br i1 %y, TrueDest, FalseDest The other case was handled correctly. // PBI: br i1 %x, BB, FalseDest // BI: br i1 %y, TrueDest, FalseDest Also tried to use 64-bit arithmetic instead of APInt with scale to simplify the computation. Let me know if you have other opinions about this. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@163954 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index a9d74cdf867..e7b5bc779a3 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1658,7 +1658,7 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI, /// parameters and return true, or returns false if no or invalid metadata was /// found. static bool ExtractBranchMetadata(BranchInst *BI, - APInt &ProbTrue, APInt &ProbFalse) { + uint64_t &ProbTrue, uint64_t &ProbFalse) { assert(BI->isConditional() && "Looking for probabilities on unconditional branch?"); MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof); @@ -1666,35 +1666,11 @@ static bool ExtractBranchMetadata(BranchInst *BI, ConstantInt *CITrue = dyn_cast(ProfileData->getOperand(1)); ConstantInt *CIFalse = dyn_cast(ProfileData->getOperand(2)); if (!CITrue || !CIFalse) return false; - ProbTrue = CITrue->getValue(); - ProbFalse = CIFalse->getValue(); - assert(ProbTrue.getBitWidth() == 32 && ProbFalse.getBitWidth() == 32 && - "Branch probability metadata must be 32-bit integers"); + ProbTrue = CITrue->getValue().getZExtValue(); + ProbFalse = CIFalse->getValue().getZExtValue(); return true; } -/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In -/// the event of overflow, logically-shifts all four inputs right until the -/// multiply fits. -static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D, - unsigned &BitsLost) { - BitsLost = 0; - bool Overflow = false; - APInt Result = A.umul_ov(B, Overflow); - if (Overflow) { - APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A); - do { - B = B.lshr(1); - ++BitsLost; - } while (B.ugt(MaxB)); - A = A.lshr(BitsLost); - C = C.lshr(BitsLost); - D = D.lshr(BitsLost); - Result = A * B; - } - return Result; -} - /// checkCSEInPredecessor - Return true if the given instruction is available /// in its predecessor block. If yes, the instruction will be removed. /// @@ -1919,14 +1895,53 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) { New, "or.cond")); PBI->setCondition(NewCond); + uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight; + bool PredHasWeights = ExtractBranchMetadata(PBI, PredTrueWeight, + PredFalseWeight); + bool SuccHasWeights = ExtractBranchMetadata(BI, SuccTrueWeight, + SuccFalseWeight); + SmallVector NewWeights; + if (PBI->getSuccessor(0) == BB) { + if (PredHasWeights && SuccHasWeights) { + // PBI: br i1 %x, BB, FalseDest + // BI: br i1 %y, TrueDest, FalseDest + //TrueWeight is TrueWeight for PBI * TrueWeight for BI. + NewWeights.push_back(PredTrueWeight * SuccTrueWeight); + //FalseWeight is FalseWeight for PBI * TotalWeight for BI + + // TrueWeight for PBI * FalseWeight for BI. + // We assume that total weights of a BranchInst can fit into 32 bits. + // Therefore, we will not have overflow using 64-bit arithmetic. + NewWeights.push_back(PredFalseWeight * (SuccFalseWeight + + SuccTrueWeight) + PredTrueWeight * SuccFalseWeight); + } AddPredecessorToBlock(TrueDest, PredBlock, BB); PBI->setSuccessor(0, TrueDest); } if (PBI->getSuccessor(1) == BB) { + if (PredHasWeights && SuccHasWeights) { + // PBI: br i1 %x, TrueDest, BB + // BI: br i1 %y, TrueDest, FalseDest + //TrueWeight is TrueWeight for PBI * TotalWeight for BI + + // FalseWeight for PBI * TrueWeight for BI. + NewWeights.push_back(PredTrueWeight * (SuccFalseWeight + + SuccTrueWeight) + PredFalseWeight * SuccTrueWeight); + //FalseWeight is FalseWeight for PBI * FalseWeight for BI. + NewWeights.push_back(PredFalseWeight * SuccFalseWeight); + } AddPredecessorToBlock(FalseDest, PredBlock, BB); PBI->setSuccessor(1, FalseDest); } + if (NewWeights.size() == 2) { + // Halve the weights if any of them cannot fit in an uint32_t + FitWeights(NewWeights); + + SmallVector MDWeights(NewWeights.begin(),NewWeights.end()); + PBI->setMetadata(LLVMContext::MD_prof, + MDBuilder(BI->getContext()). + createBranchWeights(MDWeights)); + } else + PBI->setMetadata(LLVMContext::MD_prof, NULL); } else { // Update PHI nodes in the common successors. for (unsigned i = 0, e = PHIs.size(); i != e; ++i) { @@ -1981,90 +1996,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) { // TODO: If BB is reachable from all paths through PredBlock, then we // could replace PBI's branch probabilities with BI's. - // Merge probability data into PredBlock's branch. - APInt A, B, C, D; - if (PBI->isConditional() && BI->isConditional() && - ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) { - // Given IR which does: - // bbA: - // br i1 %x, label %bbB, label %bbC - // bbB: - // br i1 %y, label %bbD, label %bbC - // Let's call the probability that we take the edge from %bbA to %bbB - // 'a', from %bbA to %bbC, 'b', from %bbB to %bbD 'c' and from %bbB to - // %bbC probability 'd'. - // - // We transform the IR into: - // bbA: - // br i1 %z, label %bbD, label %bbC - // where the probability of going to %bbD is (a*c) and going to bbC is - // (b+a*d). - // - // Probabilities aren't stored as ratios directly. Using branch weights, - // we get: - // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D. - - // In the event of overflow, we want to drop the LSB of the input - // probabilities. - unsigned BitsLost; - - // Ignore overflow result on ProbTrue. - APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost); - - APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost); - if (BitsLost) { - ProbTrue = ProbTrue.lshr(BitsLost*2); - } - - APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost); - if (BitsLost) { - ProbTrue = ProbTrue.lshr(BitsLost*2); - Tmp1 = Tmp1.lshr(BitsLost*2); - } - - APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost); - if (BitsLost) { - ProbTrue = ProbTrue.lshr(BitsLost*2); - Tmp1 = Tmp1.lshr(BitsLost*2); - Tmp2 = Tmp2.lshr(BitsLost*2); - } - - bool Overflow1 = false, Overflow2 = false; - APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1); - APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2); - - if (Overflow1 || Overflow2) { - ProbTrue = ProbTrue.lshr(1); - Tmp1 = Tmp1.lshr(1); - Tmp2 = Tmp2.lshr(1); - Tmp3 = Tmp3.lshr(1); - Tmp4 = Tmp2 + Tmp3; - ProbFalse = Tmp4 + Tmp1; - } - - // The sum of branch weights must fit in 32-bits. - if (ProbTrue.isNegative() && ProbFalse.isNegative()) { - ProbTrue = ProbTrue.lshr(1); - ProbFalse = ProbFalse.lshr(1); - } - - if (ProbTrue != ProbFalse) { - // Normalize the result. - APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse); - ProbTrue = ProbTrue.udiv(GCD); - ProbFalse = ProbFalse.udiv(GCD); - - MDBuilder MDB(BI->getContext()); - MDNode *N = MDB.createBranchWeights(ProbTrue.getZExtValue(), - ProbFalse.getZExtValue()); - PBI->setMetadata(LLVMContext::MD_prof, N); - } else { - PBI->setMetadata(LLVMContext::MD_prof, NULL); - } - } else { - PBI->setMetadata(LLVMContext::MD_prof, NULL); - } - // Copy any debug value intrinsics into the end of PredBlock. for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) if (isa(*I)) diff --git a/test/Transforms/SimplifyCFG/preserve-branchweights.ll b/test/Transforms/SimplifyCFG/preserve-branchweights.ll index 93bbcfb38c3..1564e4e11d7 100644 --- a/test/Transforms/SimplifyCFG/preserve-branchweights.ll +++ b/test/Transforms/SimplifyCFG/preserve-branchweights.ll @@ -154,6 +154,26 @@ sw.epilog: ret void } +;; This test is based on test1 but swapped the targets of the second branch. +define void @test1_swap(i1 %a, i1 %b) { +; CHECK: @test1_swap +entry: + br i1 %a, label %Y, label %X, !prof !0 +; CHECK: br i1 %or.cond, label %Y, label %Z, !prof !4 + +X: + %c = or i1 %b, false + br i1 %c, label %Y, label %Z, !prof !1 + +Y: + call void @helper(i32 0) + ret void + +Z: + call void @helper(i32 1) + ret void +} + !0 = metadata !{metadata !"branch_weights", i32 3, i32 5} !1 = metadata !{metadata !"branch_weights", i32 1, i32 1} !2 = metadata !{metadata !"branch_weights", i32 1, i32 2} @@ -165,4 +185,5 @@ sw.epilog: ; CHECK: !1 = metadata !{metadata !"branch_weights", i32 1, i32 5} ; CHECK: !2 = metadata !{metadata !"branch_weights", i32 7, i32 1, i32 2} ; CHECK: !3 = metadata !{metadata !"branch_weights", i32 49, i32 12, i32 24, i32 35} -; CHECK-NOT: !4 +; CHECK: !4 = metadata !{metadata !"branch_weights", i32 11, i32 5} +; CHECK-NOT: !5