From: Akira Hatanaka Date: Thu, 7 Aug 2014 19:30:13 +0000 (+0000) Subject: [Branch probability] Recompute branch weights of tail-merged basic blocks. X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=70b56056a1ba8e945f2e75b67575920cb816ac69;p=oota-llvm.git [Branch probability] Recompute branch weights of tail-merged basic blocks. BranchFolderPass was not correctly setting the basic block branch weights when tail-merging created or merged blocks. This patch recomputes the weights of tail-merged blocks using the following formula: branch_weight(merged block to successor j) = sum(block_frequency(bb) * branch_probability(bb -> j)), where bb is a block that is in the set of merged blocks. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@215135 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/CodeGen/BranchFolding.cpp b/lib/CodeGen/BranchFolding.cpp index 1b90ba03d49..cc446318678 100644 --- a/lib/CodeGen/BranchFolding.cpp +++ b/lib/CodeGen/BranchFolding.cpp @@ -20,6 +20,8 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/MachineBlockFrequencyInfo.h" +#include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" @@ -71,6 +73,8 @@ namespace { bool runOnMachineFunction(MachineFunction &MF) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } @@ -92,21 +96,24 @@ bool BranchFolderPass::runOnMachineFunction(MachineFunction &MF) { // HW that requires structurized CFG. 
bool EnableTailMerge = !MF.getTarget().requiresStructuredCFG() && PassConfig->getEnableTailMerge(); - BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true); + BranchFolder Folder(EnableTailMerge, /*CommonHoist=*/true, + getAnalysis(), + getAnalysis()); return Folder.OptimizeFunction(MF, MF.getSubtarget().getInstrInfo(), MF.getSubtarget().getRegisterInfo(), getAnalysisIfAvailable()); } - -BranchFolder::BranchFolder(bool defaultEnableTailMerge, bool CommonHoist) { +BranchFolder::BranchFolder(bool defaultEnableTailMerge, bool CommonHoist, + const MachineBlockFrequencyInfo &FreqInfo, + const MachineBranchProbabilityInfo &ProbInfo) + : EnableHoistCommonCode(CommonHoist), MBBFreqInfo(FreqInfo), + MBPI(ProbInfo) { switch (FlagEnableTailMerge) { case cl::BOU_UNSET: EnableTailMerge = defaultEnableTailMerge; break; case cl::BOU_TRUE: EnableTailMerge = true; break; case cl::BOU_FALSE: EnableTailMerge = false; break; } - - EnableHoistCommonCode = CommonHoist; } /// RemoveDeadBlock - Remove the specified dead machine basic block from the @@ -433,6 +440,9 @@ MachineBasicBlock *BranchFolder::SplitMBBAt(MachineBasicBlock &CurMBB, // Splice the code over. NewMBB->splice(NewMBB->end(), &CurMBB, BBI1, CurMBB.end()); + // NewMBB inherits CurMBB's block frequency. + MBBFreqInfo.setBlockFreq(NewMBB, MBBFreqInfo.getBlockFreq(&CurMBB)); + // For targets that use the register scavenger, we must maintain LiveIns. 
MaintainLiveIns(&CurMBB, NewMBB); @@ -502,6 +512,21 @@ BranchFolder::MergePotentialsElt::operator<(const MergePotentialsElt &o) const { #endif } +BlockFrequency +BranchFolder::MBFIWrapper::getBlockFreq(const MachineBasicBlock *MBB) const { + auto I = MergedBBFreq.find(MBB); + + if (I != MergedBBFreq.end()) + return I->second; + + return MBFI.getBlockFreq(MBB); +} + +void BranchFolder::MBFIWrapper::setBlockFreq(const MachineBasicBlock *MBB, + BlockFrequency F) { + MergedBBFreq[MBB] = F; +} + /// CountTerminators - Count the number of terminators in the given /// block and set I to the position of the first non-terminator, if there /// is one, or MBB->end() otherwise. @@ -804,6 +829,10 @@ bool BranchFolder::TryTailMergeBlocks(MachineBasicBlock *SuccBB, } MachineBasicBlock *MBB = SameTails[commonTailIndex].getBlock(); + + // Recompute common tail MBB's edge weights and block frequency. + setCommonTailEdgeWeights(*MBB); + // MBB is common tail. Adjust all other BB's to jump to this one. // Traversal must be forwards so erases work. DEBUG(dbgs() << "\nUsing common tail in BB#" << MBB->getNumber() @@ -966,6 +995,44 @@ bool BranchFolder::TailMergeBlocks(MachineFunction &MF) { return MadeChange; } +void BranchFolder::setCommonTailEdgeWeights(MachineBasicBlock &TailMBB) { + SmallVector EdgeFreqLs(TailMBB.succ_size()); + BlockFrequency AccumulatedMBBFreq; + + // Aggregate edge frequency of successor edge j: + // edgeFreq(j) = sum (freq(bb) * edgeProb(bb, j)), + // where bb is a basic block that is in SameTails. + for (const auto &Src : SameTails) { + const MachineBasicBlock *SrcMBB = Src.getBlock(); + BlockFrequency BlockFreq = MBBFreqInfo.getBlockFreq(SrcMBB); + AccumulatedMBBFreq += BlockFreq; + + // It is not necessary to recompute edge weights if TailBB has less than two + // successors. 
+ if (TailMBB.succ_size() <= 1) + continue; + + auto EdgeFreq = EdgeFreqLs.begin(); + + for (auto SuccI = TailMBB.succ_begin(), SuccE = TailMBB.succ_end(); + SuccI != SuccE; ++SuccI, ++EdgeFreq) + *EdgeFreq += BlockFreq * MBPI.getEdgeProbability(SrcMBB, *SuccI); + } + + MBBFreqInfo.setBlockFreq(&TailMBB, AccumulatedMBBFreq); + + if (TailMBB.succ_size() <= 1) + return; + + auto MaxEdgeFreq = *std::max_element(EdgeFreqLs.begin(), EdgeFreqLs.end()); + uint64_t Scale = MaxEdgeFreq.getFrequency() / UINT32_MAX + 1; + auto EdgeFreq = EdgeFreqLs.begin(); + + for (auto SuccI = TailMBB.succ_begin(), SuccE = TailMBB.succ_end(); + SuccI != SuccE; ++SuccI, ++EdgeFreq) + TailMBB.setSuccWeight(SuccI, EdgeFreq->getFrequency() / Scale); +} + //===----------------------------------------------------------------------===// // Branch Optimization //===----------------------------------------------------------------------===// diff --git a/lib/CodeGen/BranchFolding.h b/lib/CodeGen/BranchFolding.h index 0d15ed7e792..66b152cde00 100644 --- a/lib/CodeGen/BranchFolding.h +++ b/lib/CodeGen/BranchFolding.h @@ -12,9 +12,12 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/Support/BlockFrequency.h" #include namespace llvm { + class MachineBlockFrequencyInfo; + class MachineBranchProbabilityInfo; class MachineFunction; class MachineModuleInfo; class RegScavenger; @@ -23,7 +26,9 @@ namespace llvm { class BranchFolder { public: - explicit BranchFolder(bool defaultEnableTailMerge, bool CommonHoist); + explicit BranchFolder(bool defaultEnableTailMerge, bool CommonHoist, + const MachineBlockFrequencyInfo &MBFI, + const MachineBranchProbabilityInfo &MBPI); bool OptimizeFunction(MachineFunction &MF, const TargetInstrInfo *tii, @@ -92,9 +97,26 @@ namespace llvm { MachineModuleInfo *MMI; RegScavenger *RS; + /// \brief This class keeps track of branch frequencies of newly created + /// blocks and tail-merged blocks. 
+ class MBFIWrapper { + public: + MBFIWrapper(const MachineBlockFrequencyInfo &I) : MBFI(I) {} + BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const; + void setBlockFreq(const MachineBasicBlock *MBB, BlockFrequency F); + + private: + const MachineBlockFrequencyInfo &MBFI; + DenseMap MergedBBFreq; + }; + + MBFIWrapper MBBFreqInfo; + const MachineBranchProbabilityInfo &MBPI; + bool TailMergeBlocks(MachineFunction &MF); bool TryTailMergeBlocks(MachineBasicBlock* SuccBB, MachineBasicBlock* PredBB); + void setCommonTailEdgeWeights(MachineBasicBlock &TailMBB); void MaintainLiveIns(MachineBasicBlock *CurMBB, MachineBasicBlock *NewMBB); void ReplaceTailWithBranchTo(MachineBasicBlock::iterator OldInst, diff --git a/lib/CodeGen/IfConversion.cpp b/lib/CodeGen/IfConversion.cpp index 79eb0fc6fe8..8e99de1541c 100644 --- a/lib/CodeGen/IfConversion.cpp +++ b/lib/CodeGen/IfConversion.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/CodeGen/LivePhysRegs.h" +#include "llvm/CodeGen/MachineBlockFrequencyInfo.h" #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstrBuilder.h" @@ -161,6 +162,7 @@ namespace { const TargetLoweringBase *TLI; const TargetInstrInfo *TII; const TargetRegisterInfo *TRI; + const MachineBlockFrequencyInfo *MBFI; const MachineBranchProbabilityInfo *MBPI; MachineRegisterInfo *MRI; @@ -177,6 +179,7 @@ namespace { } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); AU.addRequired(); MachineFunctionPass::getAnalysisUsage(AU); } @@ -272,6 +275,7 @@ bool IfConverter::runOnMachineFunction(MachineFunction &MF) { TLI = MF.getSubtarget().getTargetLowering(); TII = MF.getSubtarget().getInstrInfo(); TRI = MF.getSubtarget().getRegisterInfo(); + MBFI = &getAnalysis(); MBPI = &getAnalysis(); MRI = &MF.getRegInfo(); @@ -286,7 +290,7 @@ bool IfConverter::runOnMachineFunction(MachineFunction &MF) { bool BFChange 
= false; if (!PreRegAlloc) { // Tail merge tend to expose more if-conversion opportunities. - BranchFolder BF(true, false); + BranchFolder BF(true, false, *MBFI, *MBPI); BFChange = BF.OptimizeFunction(MF, TII, MF.getSubtarget().getRegisterInfo(), getAnalysisIfAvailable()); } @@ -419,7 +423,7 @@ bool IfConverter::runOnMachineFunction(MachineFunction &MF) { BBAnalysis.clear(); if (MadeChange && IfCvtBranchFold) { - BranchFolder BF(false, false); + BranchFolder BF(false, false, *MBFI, *MBPI); BF.OptimizeFunction(MF, TII, MF.getSubtarget().getRegisterInfo(), getAnalysisIfAvailable()); } diff --git a/test/CodeGen/ARM/tail-merge-branch-weight.ll b/test/CodeGen/ARM/tail-merge-branch-weight.ll new file mode 100644 index 00000000000..9b5d566834f --- /dev/null +++ b/test/CodeGen/ARM/tail-merge-branch-weight.ll @@ -0,0 +1,44 @@ +; RUN: llc -mtriple=arm-apple-ios -print-machineinstrs=branch-folder \ +; RUN: %s -o /dev/null 2>&1 | FileCheck %s + +; Branch probability of tailed-merged block: +; +; p(L0_L1 -> L2) = p(entry -> L0) * p(L0 -> L2) + p(entry -> L1) * p(L1 -> L2) +; = 0.2 * 0.6 + 0.8 * 0.3 = 0.36 +; p(L0_L1 -> L3) = p(entry -> L0) * p(L0 -> L3) + p(entry -> L1) * p(L1 -> L3) +; = 0.2 * 0.4 + 0.8 * 0.7 = 0.64 + +; CHECK: # Machine code for function test0: +; CHECK: Successors according to CFG: BB#{{[0-9]+}}(13) BB#{{[0-9]+}}(24) +; CHECK: BB#{{[0-9]+}}: +; CHECK: BB#{{[0-9]+}}: +; CHECK: # End machine code for function test0. 
+ +define i32 @test0(i32 %n, i32 %m, i32* nocapture %a, i32* nocapture %b) { +entry: + %cmp = icmp sgt i32 %n, 0 + br i1 %cmp, label %L0, label %L1, !prof !0 + +L0: ; preds = %entry + store i32 12, i32* %a, align 4 + store i32 18, i32* %b, align 4 + %cmp1 = icmp eq i32 %m, 8 + br i1 %cmp1, label %L2, label %L3, !prof !1 + +L1: ; preds = %entry + store i32 14, i32* %a, align 4 + store i32 18, i32* %b, align 4 + %cmp3 = icmp eq i32 %m, 8 + br i1 %cmp3, label %L2, label %L3, !prof !2 + +L2: ; preds = %L1, %L0 + br label %L3 + +L3: ; preds = %L0, %L1, %L2 + %retval.0 = phi i32 [ 100, %L2 ], [ 6, %L1 ], [ 6, %L0 ] + ret i32 %retval.0 +} + +!0 = metadata !{metadata !"branch_weights", i32 200, i32 800} +!1 = metadata !{metadata !"branch_weights", i32 600, i32 400} +!2 = metadata !{metadata !"branch_weights", i32 300, i32 700}