[Branch probability] Recompute branch weights of tail-merged basic blocks.
[oota-llvm.git] / lib / CodeGen / BranchFolding.h
index 0d15ed7e792a6a5aafabf9d320bcfbdc85a2eda7..66b152cde007f360d0df1a755b6c328ddb529e0f 100644 (file)
 
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/Support/BlockFrequency.h"
 #include <vector>
 
 namespace llvm {
+  class MachineBlockFrequencyInfo;
+  class MachineBranchProbabilityInfo;
   class MachineFunction;
   class MachineModuleInfo;
   class RegScavenger;
@@ -23,7 +26,9 @@ namespace llvm {
 
   class BranchFolder {
   public:
-    explicit BranchFolder(bool defaultEnableTailMerge, bool CommonHoist);
+    explicit BranchFolder(bool defaultEnableTailMerge, bool CommonHoist,
+                          const MachineBlockFrequencyInfo &MBFI,
+                          const MachineBranchProbabilityInfo &MBPI);
 
     bool OptimizeFunction(MachineFunction &MF,
                           const TargetInstrInfo *tii,
@@ -92,9 +97,26 @@ namespace llvm {
     MachineModuleInfo *MMI;
     RegScavenger *RS;
 
+    /// \brief This class keeps track of branch frequencies of newly created
+    /// blocks and tail-merged blocks.
+    class MBFIWrapper {
+    public:
+      MBFIWrapper(const MachineBlockFrequencyInfo &I) : MBFI(I) {}
+      BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
+      void setBlockFreq(const MachineBasicBlock *MBB, BlockFrequency F);
+
+    private:
+      const MachineBlockFrequencyInfo &MBFI;
+      DenseMap<const MachineBasicBlock *, BlockFrequency> MergedBBFreq;
+    };
+
+    MBFIWrapper MBBFreqInfo;
+    const MachineBranchProbabilityInfo &MBPI;
+
     bool TailMergeBlocks(MachineFunction &MF);
     bool TryTailMergeBlocks(MachineBasicBlock* SuccBB,
                        MachineBasicBlock* PredBB);
+    void setCommonTailEdgeWeights(MachineBasicBlock &TailMBB);
     void MaintainLiveIns(MachineBasicBlock *CurMBB,
                          MachineBasicBlock *NewMBB);
     void ReplaceTailWithBranchTo(MachineBasicBlock::iterator OldInst,