Create a wrapper pass for BranchProbabilityInfo.
authorCong Hou <congh@google.com>
Wed, 15 Jul 2015 22:48:29 +0000 (22:48 +0000)
committerCong Hou <congh@google.com>
Wed, 15 Jul 2015 22:48:29 +0000 (22:48 +0000)
This new wrapper pass is useful when we want to do branch probability analysis conditionally (e.g. only in PGO mode) but don't want to add one more pass dependence.

http://reviews.llvm.org/D11241

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@242349 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/BranchProbabilityInfo.h
include/llvm/InitializePasses.h
lib/Analysis/Analysis.cpp
lib/Analysis/BlockFrequencyInfo.cpp
lib/Analysis/BranchProbabilityInfo.cpp
lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp

index 9d867567ba2942e70af9ce6705881d00f0252fc1..9f7c0f9d747e5af89eede0d7e522cac0f1c2dc48 100644 (file)
@@ -25,9 +25,9 @@ namespace llvm {
 class LoopInfo;
 class raw_ostream;
 
-/// \brief Analysis pass providing branch probability information.
+/// \brief Analysis providing branch probability information.
 ///
-/// This is a function analysis pass which provides information on the relative
+/// This is a function analysis which provides information on the relative
 /// probabilities of each "edge" in the function's CFG where such an edge is
 /// defined by a pair (PredBlock and an index in the successors). The
 /// probability of an edge from one block is always relative to the
@@ -37,20 +37,11 @@ class raw_ostream;
 /// identify an edge, since we can have multiple edges from Src to Dst.
 /// As an example, we can have a switch which jumps to Dst with value 0 and
 /// value 10.
-class BranchProbabilityInfo : public FunctionPass {
+class BranchProbabilityInfo {
 public:
-  static char ID;
-
-  BranchProbabilityInfo() : FunctionPass(ID) {
-    initializeBranchProbabilityInfoPass(*PassRegistry::getPassRegistry());
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override;
-  bool runOnFunction(Function &F) override;
+  void releaseMemory();
 
-  void releaseMemory() override;
-
-  void print(raw_ostream &OS, const Module *M = nullptr) const override;
+  void print(raw_ostream &OS) const;
 
   /// \brief Get an edge's probability, relative to other out-edges of the Src.
   ///
@@ -118,6 +109,8 @@ public:
     return IsLikely ? (1u << 20) - 1 : 1;
   }
 
+  void calculate(Function &F, const LoopInfo& LI);
+
 private:
   // Since we allow duplicate edges from one basic block to another, we use
   // a pair (PredBlock and an index in the successors) to specify an edge.
@@ -152,12 +145,33 @@ private:
   bool calcMetadataWeights(BasicBlock *BB);
   bool calcColdCallHeuristics(BasicBlock *BB);
   bool calcPointerHeuristics(BasicBlock *BB);
-  bool calcLoopBranchHeuristics(BasicBlock *BB);
+  bool calcLoopBranchHeuristics(BasicBlock *BB, const LoopInfo &LI);
   bool calcZeroHeuristics(BasicBlock *BB);
   bool calcFloatingPointHeuristics(BasicBlock *BB);
   bool calcInvokeHeuristics(BasicBlock *BB);
 };
 
+/// \brief Legacy analysis pass which computes \c BranchProbabilityInfo.
+class BranchProbabilityInfoWrapperPass : public FunctionPass {
+  BranchProbabilityInfo BPI;
+
+public:
+  static char ID;
+
+  BranchProbabilityInfoWrapperPass() : FunctionPass(ID) {
+    initializeBranchProbabilityInfoWrapperPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  BranchProbabilityInfo &getBPI() { return BPI; }
+  const BranchProbabilityInfo &getBPI() const { return BPI; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  bool runOnFunction(Function &F) override;
+  void releaseMemory() override;
+  void print(raw_ostream &OS, const Module *M = nullptr) const override;
+};
+
 }
 
 #endif
index 6aadb24dad9c0fdead5aee798f25c4254efef991..8c27b34291346e96237ff382175fa2e9c5b7100c 100644 (file)
@@ -82,7 +82,7 @@ void initializeBlockExtractorPassPass(PassRegistry&);
 void initializeBlockFrequencyInfoWrapperPassPass(PassRegistry&);
 void initializeBoundsCheckingPass(PassRegistry&);
 void initializeBranchFolderPassPass(PassRegistry&);
-void initializeBranchProbabilityInfoPass(PassRegistry&);
+void initializeBranchProbabilityInfoWrapperPassPass(PassRegistry&);
 void initializeBreakCriticalEdgesPass(PassRegistry&);
 void initializeCallGraphPrinterPass(PassRegistry&);
 void initializeCallGraphViewerPass(PassRegistry&);
index c839b2d284d6d7364f41e46744bb5b89fceb28eb..3ce87f9a76c28e89806ad3f20bfa0ca2a2b98ba3 100644 (file)
@@ -28,7 +28,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) {
   initializeNoAAPass(Registry);
   initializeBasicAliasAnalysisPass(Registry);
   initializeBlockFrequencyInfoWrapperPassPass(Registry);
-  initializeBranchProbabilityInfoPass(Registry);
+  initializeBranchProbabilityInfoWrapperPassPass(Registry);
   initializeCostModelAnalysisPass(Registry);
   initializeCFGViewerPass(Registry);
   initializeCFGPrinterPass(Registry);
index be82f8cf1bcc6e5a4608252bcd0eff59d2e20a52..46095ffd1d072a0ed2267fc432c2a175616f3410 100644 (file)
@@ -162,7 +162,7 @@ void BlockFrequencyInfo::print(raw_ostream &OS) const {
 
 INITIALIZE_PASS_BEGIN(BlockFrequencyInfoWrapperPass, "block-freq",
                       "Block Frequency Analysis", true, true)
-INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfo)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_END(BlockFrequencyInfoWrapperPass, "block-freq",
                     "Block Frequency Analysis", true, true)
@@ -183,7 +183,7 @@ void BlockFrequencyInfoWrapperPass::print(raw_ostream &OS,
 }
 
 void BlockFrequencyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.addRequired<BranchProbabilityInfo>();
+  AU.addRequired<BranchProbabilityInfoWrapperPass>();
   AU.addRequired<LoopInfoWrapperPass>();
   AU.setPreservesAll();
 }
@@ -191,7 +191,8 @@ void BlockFrequencyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
 void BlockFrequencyInfoWrapperPass::releaseMemory() { BFI.releaseMemory(); }
 
 bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) {
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   BFI.calculate(F, BPI, LI);
   return false;
index 430b41241edf59bcd26d2247079096de27c41e5a..b813dca9369a029ccaba9d51bdd4800cceb03d52 100644 (file)
@@ -27,13 +27,13 @@ using namespace llvm;
 
 #define DEBUG_TYPE "branch-prob"
 
-INITIALIZE_PASS_BEGIN(BranchProbabilityInfo, "branch-prob",
+INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
                       "Branch Probability Analysis", false, true)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(BranchProbabilityInfo, "branch-prob",
+INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
                     "Branch Probability Analysis", false, true)
 
-char BranchProbabilityInfo::ID = 0;
+char BranchProbabilityInfoWrapperPass::ID = 0;
 
 // Weights are for internal use only. They are used by heuristics to help to
 // estimate edges' probability. Example:
@@ -319,8 +319,9 @@ bool BranchProbabilityInfo::calcPointerHeuristics(BasicBlock *BB) {
 
 // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
 // as taken, exiting edges as not-taken.
-bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB) {
-  Loop *L = LI->getLoopFor(BB);
+bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB,
+                                                     const LoopInfo &LI) {
+  Loop *L = LI.getLoopFor(BB);
   if (!L)
     return false;
 
@@ -504,50 +505,11 @@ bool BranchProbabilityInfo::calcInvokeHeuristics(BasicBlock *BB) {
   return true;
 }
 
-void BranchProbabilityInfo::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.addRequired<LoopInfoWrapperPass>();
-  AU.setPreservesAll();
-}
-
-bool BranchProbabilityInfo::runOnFunction(Function &F) {
-  DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
-               << " ----\n\n");
-  LastF = &F; // Store the last function we ran on for printing.
-  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-  assert(PostDominatedByUnreachable.empty());
-  assert(PostDominatedByColdCall.empty());
-
-  // Walk the basic blocks in post-order so that we can build up state about
-  // the successors of a block iteratively.
-  for (auto BB : post_order(&F.getEntryBlock())) {
-    DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n");
-    if (calcUnreachableHeuristics(BB))
-      continue;
-    if (calcMetadataWeights(BB))
-      continue;
-    if (calcColdCallHeuristics(BB))
-      continue;
-    if (calcLoopBranchHeuristics(BB))
-      continue;
-    if (calcPointerHeuristics(BB))
-      continue;
-    if (calcZeroHeuristics(BB))
-      continue;
-    if (calcFloatingPointHeuristics(BB))
-      continue;
-    calcInvokeHeuristics(BB);
-  }
-
-  PostDominatedByUnreachable.clear();
-  PostDominatedByColdCall.clear();
-  return false;
-}
-
 void BranchProbabilityInfo::releaseMemory() {
   Weights.clear();
 }
 
-void BranchProbabilityInfo::print(raw_ostream &OS, const Module *) const {
+void BranchProbabilityInfo::print(raw_ostream &OS) const {
   OS << "---- Branch Probabilities ----\n";
   // We print the probabilities from the last function the analysis ran over,
   // or the function it is currently running over.
@@ -688,3 +650,54 @@ BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
 
   return OS;
 }
+
+void BranchProbabilityInfo::calculate(Function &F, const LoopInfo& LI) {
+  DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
+               << " ----\n\n");
+  LastF = &F; // Store the last function we ran on for printing.
+  assert(PostDominatedByUnreachable.empty());
+  assert(PostDominatedByColdCall.empty());
+
+  // Walk the basic blocks in post-order so that we can build up state about
+  // the successors of a block iteratively.
+  for (auto BB : post_order(&F.getEntryBlock())) {
+    DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n");
+    if (calcUnreachableHeuristics(BB))
+      continue;
+    if (calcMetadataWeights(BB))
+      continue;
+    if (calcColdCallHeuristics(BB))
+      continue;
+    if (calcLoopBranchHeuristics(BB, LI))
+      continue;
+    if (calcPointerHeuristics(BB))
+      continue;
+    if (calcZeroHeuristics(BB))
+      continue;
+    if (calcFloatingPointHeuristics(BB))
+      continue;
+    calcInvokeHeuristics(BB);
+  }
+
+  PostDominatedByUnreachable.clear();
+  PostDominatedByColdCall.clear();
+}
+
+void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.setPreservesAll();
+}
+
+bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
+  const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  BPI.calculate(F, LI);
+  return false;
+}
+
+void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
+
+void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
+                                             const Module *) const {
+  BPI.print(OS);
+}
index 97ece8b9248aa1651cbdac1f295e14f97a53b8c2..12111f50fe5ead454ad41480528dcaafc7e2aa6f 100644 (file)
@@ -351,7 +351,8 @@ SelectionDAGISel::SelectionDAGISel(TargetMachine &tm,
   DAGSize(0) {
     initializeGCModuleInfoPass(*PassRegistry::getPassRegistry());
     initializeAliasAnalysisAnalysisGroup(*PassRegistry::getPassRegistry());
-    initializeBranchProbabilityInfoPass(*PassRegistry::getPassRegistry());
+    initializeBranchProbabilityInfoWrapperPassPass(
+        *PassRegistry::getPassRegistry());
     initializeTargetLibraryInfoWrapperPassPass(
         *PassRegistry::getPassRegistry());
   }
@@ -369,7 +370,7 @@ void SelectionDAGISel::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addPreserved<GCModuleInfo>();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
   if (UseMBPI && OptLevel != CodeGenOpt::None)
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
@@ -449,7 +450,7 @@ bool SelectionDAGISel::runOnMachineFunction(MachineFunction &mf) {
   FuncInfo->set(Fn, *MF, CurDAG);
 
   if (UseMBPI && OptLevel != CodeGenOpt::None)
-    FuncInfo->BPI = &getAnalysis<BranchProbabilityInfo>();
+    FuncInfo->BPI = &getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
   else
     FuncInfo->BPI = nullptr;
 
index cbdacad8f28b50addb6ec83185e3829dc1ebb009..08fdcc38c045d8a48545756efe2f1f0930d82b42 100644 (file)
@@ -215,7 +215,7 @@ public:
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequiredID(LCSSAID);
     AU.addRequired<ScalarEvolution>();
-    AU.addRequired<BranchProbabilityInfo>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   }
 
   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
@@ -1400,7 +1400,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   InductiveRangeCheck::AllocatorTy IRCAlloc;
   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
   ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
-  BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfo>();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
 
   for (auto BBI : L->getBlocks())
     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))