ScalarEvolution: Analyze trip count of loops with a switch guarding the exit.
authorBenjamin Kramer <benny.kra@googlemail.com>
Tue, 11 Feb 2014 15:44:32 +0000 (15:44 +0000)
committerBenjamin Kramer <benny.kra@googlemail.com>
Tue, 11 Feb 2014 15:44:32 +0000 (15:44 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@201159 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/trip-count-switch.ll [new file with mode: 0644]

index 80809da81ae47da6ce1b2c0be003c1421a0597e1..36119acf4038992038c3950cfa1b8e6faf82c93b 100644 (file)
@@ -469,6 +469,13 @@ namespace llvm {
                                        BasicBlock *FBB,
                                        bool IsSubExpr);
 
+    /// ComputeExitLimitFromSingleExitSwitch - Compute the number of times the
+    /// backedge of the specified loop will execute if its exit condition were a
+    /// switch with a single exiting case to ExitingBB.
+    ExitLimit
+    ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch,
+                               BasicBlock *ExitingBB, bool IsSubExpr);
+
     /// ComputeLoadConstantCompareExitLimit - Given an exit condition
     /// of 'icmp op load X, cst', try to see if we can compute the
     /// backedge-taken count.
index 1a15144863dc302c190fdac64be938769aa38c24..ec7e2211a86adfe79ed9ad451347812bafede374 100644 (file)
@@ -4453,12 +4453,19 @@ ScalarEvolution::ExitLimit
 ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
 
   // Okay, we've chosen an exiting block.  See what condition causes us to
-  // exit at this block.
-  //
-  // FIXME: we should be able to handle switch instructions (with a single exit)
-  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
-  if (ExitBr == 0) return getCouldNotCompute();
-  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
+  // exit at this block and remember the exit block and whether all other targets
+  // lead to the loop header.
+  bool MustExecuteLoopHeader = true;
+  BasicBlock *Exit = 0;
+  for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock);
+       SI != SE; ++SI)
+    if (!L->contains(*SI)) {
+      if (Exit) // Multiple exit successors.
+        return getCouldNotCompute();
+      Exit = *SI;
+    } else if (*SI != L->getHeader()) {
+      MustExecuteLoopHeader = false;
+    }
 
   // At this point, we know we have a conditional branch that determines whether
   // the loop is exited.  However, we don't know if the branch is executed each
@@ -4477,13 +4484,11 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
   //
   //  More extensive analysis could be done to handle more cases here.
   //
-  if (ExitBr->getSuccessor(0) != L->getHeader() &&
-      ExitBr->getSuccessor(1) != L->getHeader() &&
-      ExitBr->getParent() != L->getHeader()) {
+  if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) {
     // The simple checks failed, try climbing the unique predecessor chain
     // up to the header.
     bool Ok = false;
-    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
+    for (BasicBlock *BB = ExitingBlock; BB; ) {
       BasicBlock *Pred = BB->getUniquePredecessor();
       if (!Pred)
         return getCouldNotCompute();
@@ -4507,11 +4512,20 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
       return getCouldNotCompute();
   }
 
-  // Proceed to the next level to examine the exit condition expression.
-  return ComputeExitLimitFromCond(L, ExitBr->getCondition(),
-                                  ExitBr->getSuccessor(0),
-                                  ExitBr->getSuccessor(1),
-                                  /*IsSubExpr=*/false);
+  TerminatorInst *Term = ExitingBlock->getTerminator();
+  if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
+    assert(BI->isConditional() && "If unconditional, it can't be in loop!");
+    // Proceed to the next level to examine the exit condition expression.
+    return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0),
+                                    BI->getSuccessor(1),
+                                    /*IsSubExpr=*/false);
+  }
+
+  if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
+    return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                                /*IsSubExpr=*/false);
+
+  return getCouldNotCompute();
 }
 
 /// ComputeExitLimitFromCond - Compute the number of times the
@@ -4728,6 +4742,30 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L,
   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L,
+                                                      SwitchInst *Switch,
+                                                      BasicBlock *ExitingBlock,
+                                                      bool IsSubExpr) {
+  assert(!L->contains(ExitingBlock) && "Not an exiting block!");
+
+  // Give up if the exit is the default dest of a switch.
+  if (Switch->getDefaultDest() == ExitingBlock)
+    return getCouldNotCompute();
+
+  assert(L->contains(Switch->getDefaultDest()) &&
+         "Default case must not exit the loop!");
+  const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
+  const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
+
+  // while (X != Y) --> while (X-Y != 0)
+  ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr);
+  if (EL.hasAnyInfo())
+    return EL;
+
+  return getCouldNotCompute();
+}
+
 static ConstantInt *
 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
                                 ScalarEvolution &SE) {
diff --git a/test/Analysis/ScalarEvolution/trip-count-switch.ll b/test/Analysis/ScalarEvolution/trip-count-switch.ll
new file mode 100644 (file)
index 0000000..2d2b6b4
--- /dev/null
@@ -0,0 +1,30 @@
+; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
+
+declare void @foo()
+
+define void @test1() nounwind {
+entry:
+  br label %for.cond
+
+for.cond:                                         ; preds = %if.end, %entry
+  %i.0 = phi i32 [ 2, %entry ], [ %dec, %if.end ]
+  switch i32 %i.0, label %if.end [
+    i32 0, label %for.end
+    i32 1, label %if.then
+  ]
+
+if.then:                                          ; preds = %for.cond
+  tail call void @foo()
+  br label %if.end
+
+if.end:                                           ; preds = %for.cond, %if.then
+  %dec = add nsw i32 %i.0, -1
+  br label %for.cond
+
+for.end:                                          ; preds = %for.cond
+  ret void
+
+; CHECK-LABEL: @test1
+; CHECK: Loop %for.cond: backedge-taken count is 2
+; CHECK: Loop %for.cond: max backedge-taken count is 2
+}