Enhance induction variable code to remove the
authorDale Johannesen <dalej@apple.com>
Wed, 15 Apr 2009 01:10:12 +0000 (01:10 +0000)
committerDale Johannesen <dalej@apple.com>
Wed, 15 Apr 2009 01:10:12 +0000 (01:10 +0000)
sext around sext(shorter IV + constant), using a
longer IV instead, when it can figure out the
add can't overflow.  This comes up a lot in
subscripting; mainly affects 64 bit.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@69123 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/IndVarSimplify.cpp
test/Transforms/IndVarSimplify/2009-04-14-shorten_iv_vars.ll [new file with mode: 0644]

index 205efca6ce244b4f6496b510edd18cb1773fea9e..20d389b1bf8ee1930cf0abeff08944fd2aa48465 100644 (file)
@@ -467,8 +467,12 @@ static const Type *getEffectiveIndvarType(const PHINode *Phi) {
 /// whether an induction variable in the same type that starts
 /// at 0 would undergo signed overflow.
 ///
-/// In addition to setting the NoSignedWrap, and NoUnsignedWrap,
-/// variables, return the PHI for this induction variable.
+/// In addition to setting the NoSignedWrap and NoUnsignedWrap
+/// variables to true when appropriate (they are not set to false here),
+/// return the PHI for this induction variable.  Also record the initial
+/// and final values and the increment; these are not meaningful unless
+/// either NoSignedWrap or NoUnsignedWrap is true, and are always meaningful
+/// in that case, although the final value may be 0 indicating a nonconstant.
 ///
 /// TODO: This duplicates a fair amount of ScalarEvolution logic.
 /// Perhaps this can be merged with
@@ -479,7 +483,10 @@ static const PHINode *TestOrigIVForWrap(const Loop *L,
                                         const BranchInst *BI,
                                         const Instruction *OrigCond,
                                         bool &NoSignedWrap,
-                                        bool &NoUnsignedWrap) {
+                                        bool &NoUnsignedWrap,
+                                        const ConstantInt* &InitialVal,
+                                        const ConstantInt* &IncrVal,
+                                        const ConstantInt* &LimitVal) {
   // Verify that the loop is sane and find the exit condition.
   const ICmpInst *Cmp = dyn_cast<ICmpInst>(OrigCond);
   if (!Cmp) return 0;
@@ -542,31 +549,31 @@ static const PHINode *TestOrigIVForWrap(const Loop *L,
   // Get the increment instruction. Look past casts if we will
   // be able to prove that the original induction variable doesn't
   // undergo signed or unsigned overflow, respectively.
-  const Value *IncrVal = CmpLHS;
+  const Value *IncrInst = CmpLHS;
   if (isSigned) {
     if (const SExtInst *SI = dyn_cast<SExtInst>(CmpLHS)) {
       if (!isa<ConstantInt>(CmpRHS) ||
           !cast<ConstantInt>(CmpRHS)->getValue()
-            .isSignedIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
+            .isSignedIntN(IncrInst->getType()->getPrimitiveSizeInBits()))
         return 0;
-      IncrVal = SI->getOperand(0);
+      IncrInst = SI->getOperand(0);
     }
   } else {
     if (const ZExtInst *ZI = dyn_cast<ZExtInst>(CmpLHS)) {
       if (!isa<ConstantInt>(CmpRHS) ||
           !cast<ConstantInt>(CmpRHS)->getValue()
-            .isIntN(IncrVal->getType()->getPrimitiveSizeInBits()))
+            .isIntN(IncrInst->getType()->getPrimitiveSizeInBits()))
         return 0;
-      IncrVal = ZI->getOperand(0);
+      IncrInst = ZI->getOperand(0);
     }
   }
 
   // For now, only analyze induction variables that have simple increments.
-  const BinaryOperator *IncrOp = dyn_cast<BinaryOperator>(IncrVal);
-  if (!IncrOp ||
-      IncrOp->getOpcode() != Instruction::Add ||
-      !isa<ConstantInt>(IncrOp->getOperand(1)) ||
-      !cast<ConstantInt>(IncrOp->getOperand(1))->equalsInt(1))
+  const BinaryOperator *IncrOp = dyn_cast<BinaryOperator>(IncrInst);
+  if (!IncrOp || IncrOp->getOpcode() != Instruction::Add)
+    return 0;
+  IncrVal = dyn_cast<ConstantInt>(IncrOp->getOperand(1));
+  if (!IncrVal)
     return 0;
 
   // Make sure the PHI looks like a normal IV.
@@ -584,21 +591,78 @@ static const PHINode *TestOrigIVForWrap(const Loop *L,
   // For now, only analyze loops with a constant start value, so that
   // we can easily determine if the start value is not a maximum value
   // which would wrap on the first iteration.
-  const ConstantInt *InitialVal =
-    dyn_cast<ConstantInt>(PN->getIncomingValue(IncomingEdge));
+  InitialVal = dyn_cast<ConstantInt>(PN->getIncomingValue(IncomingEdge));
   if (!InitialVal)
     return 0;
 
-  // The original induction variable will start at some non-max value,
-  // it counts up by one, and the loop iterates only while it remans
-  // less than some value in the same type. As such, it will never wrap.
+  // The upper limit need not be a constant; we'll check later.
+  LimitVal = dyn_cast<ConstantInt>(CmpRHS);
+
+  // We detect the impossibility of wrapping in two cases, both of
+  // which require starting with a non-max value:
+  // - The IV counts up by one, and the loop iterates only while it remains
+  // less than a limiting value (any) in the same type.
+  // - The IV counts up by a positive increment other than 1, and the
+  // constant limiting value + the increment is less than the max value
+  // (computed as max-increment to avoid overflow)
   if (isSigned && !InitialVal->getValue().isMaxSignedValue()) {
-    NoSignedWrap = true;
-  } else if (!isSigned && !InitialVal->getValue().isMaxValue())
-    NoUnsignedWrap = true;
+    if (IncrVal->equalsInt(1))
+      NoSignedWrap = true;    // LimitVal need not be constant
+    else if (LimitVal) {
+      uint64_t numBits = LimitVal->getValue().getBitWidth();
+      if (IncrVal->getValue().sgt(APInt::getNullValue(numBits)) &&
+          (APInt::getSignedMaxValue(numBits) - IncrVal->getValue())
+            .sgt(LimitVal->getValue()))
+        NoSignedWrap = true;
+    }
+  } else if (!isSigned && !InitialVal->getValue().isMaxValue()) {
+    if (IncrVal->equalsInt(1))
+      NoUnsignedWrap = true;  // LimitVal need not be constant
+    else if (LimitVal) {
+      uint64_t numBits = LimitVal->getValue().getBitWidth();
+      if (IncrVal->getValue().ugt(APInt::getNullValue(numBits)) &&
+          (APInt::getMaxValue(numBits) - IncrVal->getValue())
+            .ugt(LimitVal->getValue()))
+        NoUnsignedWrap = true;
+    }
+  }
   return PN;
 }
 
+static Value *getSignExtendedTruncVar(const SCEVAddRecExpr *AR,
+                                      ScalarEvolution *SE,
+                                      const Type *LargestType, Loop *L, 
+                                      const Type *myType,
+                                      SCEVExpander &Rewriter, 
+                                      BasicBlock::iterator InsertPt) {
+  SCEVHandle ExtendedStart =
+    SE->getSignExtendExpr(AR->getStart(), LargestType);
+  SCEVHandle ExtendedStep =
+    SE->getSignExtendExpr(AR->getStepRecurrence(*SE), LargestType);
+  SCEVHandle ExtendedAddRec =
+    SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
+  if (LargestType != myType)
+    ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, myType);
+  return Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
+}
+
+static Value *getZeroExtendedTruncVar(const SCEVAddRecExpr *AR,
+                                      ScalarEvolution *SE,
+                                      const Type *LargestType, Loop *L, 
+                                      const Type *myType,
+                                      SCEVExpander &Rewriter, 
+                                      BasicBlock::iterator InsertPt) {
+  SCEVHandle ExtendedStart =
+    SE->getZeroExtendExpr(AR->getStart(), LargestType);
+  SCEVHandle ExtendedStep =
+    SE->getZeroExtendExpr(AR->getStepRecurrence(*SE), LargestType);
+  SCEVHandle ExtendedAddRec =
+    SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
+  if (LargestType != myType)
+    ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, myType);
+  return Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
+}
+
 bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
   LI = &getAnalysis<LoopInfo>();
   SE = &getAnalysis<ScalarEvolution>();
@@ -680,6 +744,7 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
   // using it.  We can currently only handle loops with a single exit.
   bool NoSignedWrap = false;
   bool NoUnsignedWrap = false;
+  const ConstantInt* InitialVal, * IncrVal, * LimitVal;
   const PHINode *OrigControllingPHI = 0;
   if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && ExitingBlock)
     // Can't rewrite non-branch yet.
@@ -688,7 +753,8 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
         // Determine if the OrigIV will ever undergo overflow.
         OrigControllingPHI =
           TestOrigIVForWrap(L, BI, OrigCond,
-                            NoSignedWrap, NoUnsignedWrap);
+                            NoSignedWrap, NoUnsignedWrap,
+                            InitialVal, IncrVal, LimitVal);
 
         // We'll be replacing the original condition, so it'll be dead.
         DeadInsts.insert(OrigCond);
@@ -733,29 +799,44 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
       for (Value::use_iterator UI = PN->use_begin(), UE = PN->use_end();
            UI != UE; ++UI) {
         if (isa<SExtInst>(UI) && NoSignedWrap) {
-          SCEVHandle ExtendedStart =
-            SE->getSignExtendExpr(AR->getStart(), LargestType);
-          SCEVHandle ExtendedStep =
-            SE->getSignExtendExpr(AR->getStepRecurrence(*SE), LargestType);
-          SCEVHandle ExtendedAddRec =
-            SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
-          if (LargestType != UI->getType())
-            ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType());
-          Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
+          Value *TruncIndVar = getSignExtendedTruncVar(AR, SE, LargestType, L, 
+                                            UI->getType(), Rewriter, InsertPt);
           UI->replaceAllUsesWith(TruncIndVar);
           if (Instruction *DeadUse = dyn_cast<Instruction>(*UI))
             DeadInsts.insert(DeadUse);
         }
+        // See if we can figure out sext(i+constant) doesn't wrap, so we can
+        // use a larger add.  This is common in subscripting.
+        Instruction *UInst = dyn_cast<Instruction>(*UI);
+        if (UInst && UInst->getOpcode()==Instruction::Add &&
+            UInst->hasOneUse() &&
+            isa<ConstantInt>(UInst->getOperand(1)) &&
+            isa<SExtInst>(UInst->use_begin()) && NoSignedWrap && LimitVal) {
+          uint64_t numBits = LimitVal->getValue().getBitWidth();
+          ConstantInt* RHS = dyn_cast<ConstantInt>(UInst->getOperand(1));
+          if (((APInt::getSignedMaxValue(numBits) - IncrVal->getValue()) -
+                RHS->getValue()).sgt(LimitVal->getValue())) {
+            SExtInst* oldSext = dyn_cast<SExtInst>(UInst->use_begin());
+            Value *TruncIndVar = getSignExtendedTruncVar(AR, SE, LargestType, L,
+                                              oldSext->getType(), Rewriter,
+                                              InsertPt);
+            APInt APcopy = APInt(RHS->getValue());
+            ConstantInt* newRHS = 
+                  ConstantInt::get(APcopy.sext(oldSext->getType()->
+                                               getPrimitiveSizeInBits()));
+            Value *NewAdd = BinaryOperator::CreateAdd(TruncIndVar, newRHS,
+                                                      UInst->getName()+".nosex",
+                                                      UInst);
+            oldSext->replaceAllUsesWith(NewAdd);
+            if (Instruction *DeadUse = dyn_cast<Instruction>(oldSext))
+              DeadInsts.insert(DeadUse);
+            if (Instruction *DeadUse = dyn_cast<Instruction>(UInst))
+              DeadInsts.insert(DeadUse);
+          }
+        }
         if (isa<ZExtInst>(UI) && NoUnsignedWrap) {
-          SCEVHandle ExtendedStart =
-            SE->getZeroExtendExpr(AR->getStart(), LargestType);
-          SCEVHandle ExtendedStep =
-            SE->getZeroExtendExpr(AR->getStepRecurrence(*SE), LargestType);
-          SCEVHandle ExtendedAddRec =
-            SE->getAddRecExpr(ExtendedStart, ExtendedStep, L);
-          if (LargestType != UI->getType())
-            ExtendedAddRec = SE->getTruncateExpr(ExtendedAddRec, UI->getType());
-          Value *TruncIndVar = Rewriter.expandCodeFor(ExtendedAddRec, InsertPt);
+          Value *TruncIndVar = getZeroExtendedTruncVar(AR, SE, LargestType, L, 
+                                            UI->getType(), Rewriter, InsertPt);
           UI->replaceAllUsesWith(TruncIndVar);
           if (Instruction *DeadUse = dyn_cast<Instruction>(*UI))
             DeadInsts.insert(DeadUse);
diff --git a/test/Transforms/IndVarSimplify/2009-04-14-shorten_iv_vars.ll b/test/Transforms/IndVarSimplify/2009-04-14-shorten_iv_vars.ll
new file mode 100644 (file)
index 0000000..134c9c7
--- /dev/null
@@ -0,0 +1,114 @@
+; RUN: llvm-as < %s | opt -indvars | llvm-dis | not grep {sext}
+; ModuleID = '<stdin>'
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128"
+target triple = "x86_64-apple-darwin9.6"
+@a = external global i32*              ; <i32**> [#uses=3]
+@b = external global i32*              ; <i32**> [#uses=3]
+@c = external global i32*              ; <i32**> [#uses=3]
+@d = external global i32*              ; <i32**> [#uses=3]
+@e = external global i32*              ; <i32**> [#uses=3]
+@f = external global i32*              ; <i32**> [#uses=3]
+
+define void @foo() nounwind {
+bb1.thread:
+       br label %bb1
+
+bb1:           ; preds = %bb1, %bb1.thread
+       %i.0.reg2mem.0 = phi i32 [ 0, %bb1.thread ], [ %84, %bb1 ]              ; <i32> [#uses=19]
+       %0 = load i32** @a, align 8             ; <i32*> [#uses=1]
+       %1 = load i32** @b, align 8             ; <i32*> [#uses=1]
+       %2 = sext i32 %i.0.reg2mem.0 to i64             ; <i64> [#uses=1]
+       %3 = getelementptr i32* %1, i64 %2              ; <i32*> [#uses=1]
+       %4 = load i32* %3, align 1              ; <i32> [#uses=1]
+       %5 = load i32** @c, align 8             ; <i32*> [#uses=1]
+       %6 = sext i32 %i.0.reg2mem.0 to i64             ; <i64> [#uses=1]
+       %7 = getelementptr i32* %5, i64 %6              ; <i32*> [#uses=1]
+       %8 = load i32* %7, align 1              ; <i32> [#uses=1]
+       %9 = add i32 %8, %4             ; <i32> [#uses=1]
+       %10 = sext i32 %i.0.reg2mem.0 to i64            ; <i64> [#uses=1]
+       %11 = getelementptr i32* %0, i64 %10            ; <i32*> [#uses=1]
+       store i32 %9, i32* %11, align 1
+       %12 = load i32** @a, align 8            ; <i32*> [#uses=1]
+       %13 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=1]
+       %14 = load i32** @b, align 8            ; <i32*> [#uses=1]
+       %15 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=1]
+       %16 = sext i32 %15 to i64               ; <i64> [#uses=1]
+       %17 = getelementptr i32* %14, i64 %16           ; <i32*> [#uses=1]
+       %18 = load i32* %17, align 1            ; <i32> [#uses=1]
+       %19 = load i32** @c, align 8            ; <i32*> [#uses=1]
+       %20 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=1]
+       %21 = sext i32 %20 to i64               ; <i64> [#uses=1]
+       %22 = getelementptr i32* %19, i64 %21           ; <i32*> [#uses=1]
+       %23 = load i32* %22, align 1            ; <i32> [#uses=1]
+       %24 = add i32 %23, %18          ; <i32> [#uses=1]
+       %25 = sext i32 %13 to i64               ; <i64> [#uses=1]
+       %26 = getelementptr i32* %12, i64 %25           ; <i32*> [#uses=1]
+       store i32 %24, i32* %26, align 1
+       %27 = load i32** @a, align 8            ; <i32*> [#uses=1]
+       %28 = add i32 %i.0.reg2mem.0, 2         ; <i32> [#uses=1]
+       %29 = load i32** @b, align 8            ; <i32*> [#uses=1]
+       %30 = add i32 %i.0.reg2mem.0, 2         ; <i32> [#uses=1]
+       %31 = sext i32 %30 to i64               ; <i64> [#uses=1]
+       %32 = getelementptr i32* %29, i64 %31           ; <i32*> [#uses=1]
+       %33 = load i32* %32, align 1            ; <i32> [#uses=1]
+       %34 = load i32** @c, align 8            ; <i32*> [#uses=1]
+       %35 = add i32 %i.0.reg2mem.0, 2         ; <i32> [#uses=1]
+       %36 = sext i32 %35 to i64               ; <i64> [#uses=1]
+       %37 = getelementptr i32* %34, i64 %36           ; <i32*> [#uses=1]
+       %38 = load i32* %37, align 1            ; <i32> [#uses=1]
+       %39 = add i32 %38, %33          ; <i32> [#uses=1]
+       %40 = sext i32 %28 to i64               ; <i64> [#uses=1]
+       %41 = getelementptr i32* %27, i64 %40           ; <i32*> [#uses=1]
+       store i32 %39, i32* %41, align 1
+       %42 = load i32** @d, align 8            ; <i32*> [#uses=1]
+       %43 = load i32** @e, align 8            ; <i32*> [#uses=1]
+       %44 = sext i32 %i.0.reg2mem.0 to i64            ; <i64> [#uses=1]
+       %45 = getelementptr i32* %43, i64 %44           ; <i32*> [#uses=1]
+       %46 = load i32* %45, align 1            ; <i32> [#uses=1]
+       %47 = load i32** @f, align 8            ; <i32*> [#uses=1]
+       %48 = sext i32 %i.0.reg2mem.0 to i64            ; <i64> [#uses=1]
+       %49 = getelementptr i32* %47, i64 %48           ; <i32*> [#uses=1]
+       %50 = load i32* %49, align 1            ; <i32> [#uses=1]
+       %51 = add i32 %50, %46          ; <i32> [#uses=1]
+       %52 = sext i32 %i.0.reg2mem.0 to i64            ; <i64> [#uses=1]
+       %53 = getelementptr i32* %42, i64 %52           ; <i32*> [#uses=1]
+       store i32 %51, i32* %53, align 1
+       %54 = load i32** @d, align 8            ; <i32*> [#uses=1]
+       %55 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=1]
+       %56 = load i32** @e, align 8            ; <i32*> [#uses=1]
+       %57 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=1]
+       %58 = sext i32 %57 to i64               ; <i64> [#uses=1]
+       %59 = getelementptr i32* %56, i64 %58           ; <i32*> [#uses=1]
+       %60 = load i32* %59, align 1            ; <i32> [#uses=1]
+       %61 = load i32** @f, align 8            ; <i32*> [#uses=1]
+       %62 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=1]
+       %63 = sext i32 %62 to i64               ; <i64> [#uses=1]
+       %64 = getelementptr i32* %61, i64 %63           ; <i32*> [#uses=1]
+       %65 = load i32* %64, align 1            ; <i32> [#uses=1]
+       %66 = add i32 %65, %60          ; <i32> [#uses=1]
+       %67 = sext i32 %55 to i64               ; <i64> [#uses=1]
+       %68 = getelementptr i32* %54, i64 %67           ; <i32*> [#uses=1]
+       store i32 %66, i32* %68, align 1
+       %69 = load i32** @d, align 8            ; <i32*> [#uses=1]
+       %70 = add i32 %i.0.reg2mem.0, 2         ; <i32> [#uses=1]
+       %71 = load i32** @e, align 8            ; <i32*> [#uses=1]
+       %72 = add i32 %i.0.reg2mem.0, 2         ; <i32> [#uses=1]
+       %73 = sext i32 %72 to i64               ; <i64> [#uses=1]
+       %74 = getelementptr i32* %71, i64 %73           ; <i32*> [#uses=1]
+       %75 = load i32* %74, align 1            ; <i32> [#uses=1]
+       %76 = load i32** @f, align 8            ; <i32*> [#uses=1]
+       %77 = add i32 %i.0.reg2mem.0, 2         ; <i32> [#uses=1]
+       %78 = sext i32 %77 to i64               ; <i64> [#uses=1]
+       %79 = getelementptr i32* %76, i64 %78           ; <i32*> [#uses=1]
+       %80 = load i32* %79, align 1            ; <i32> [#uses=1]
+       %81 = add i32 %80, %75          ; <i32> [#uses=1]
+       %82 = sext i32 %70 to i64               ; <i64> [#uses=1]
+       %83 = getelementptr i32* %69, i64 %82           ; <i32*> [#uses=1]
+       store i32 %81, i32* %83, align 1
+       %84 = add i32 %i.0.reg2mem.0, 1         ; <i32> [#uses=2]
+       %85 = icmp sgt i32 %84, 23646           ; <i1> [#uses=1]
+       br i1 %85, label %return, label %bb1
+
+return:                ; preds = %bb1
+       ret void
+}