author    Dan Gohman <gohman@apple.com>
          Wed, 29 Apr 2009 22:28:28 +0000 (22:28 +0000)
committer Dan Gohman <gohman@apple.com>
          Wed, 29 Apr 2009 22:28:28 +0000 (22:28 +0000)

Generalize the cast-of-addrec folding to handle folding of SCEVs like
(sext i8 {-128,+,1} to i64) to i64 {-128,+,1}, where the iteration
crosses from negative to positive, but is still safe if the trip
count is within range.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@70421 91177308-0d34-0410-b5e6-96231b3b80d8
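
For concreteness (using the numbers from the sext-iv-0.ll test added
below): the i8 addrec {-128,+,1} runs for a backedge-taken count of
255, so its final value is -128 + 1*255 = 127, which is still
representable in i8. The recurrence crosses from negative to positive
but never wraps, so sign-extending each i8 value yields exactly the
i64 addrec {-128,+,1}. The new code verifies this by recomputing the
final value in a type twice as wide and comparing:

    narrow (i8, each op may wrap):  -128 + 1*255                   = 127
    wide (i16):                     sext(-128) + zext(255)*sext(1) = 127

The two agree, so no signed overflow occurred and the extension can be
hoisted through the addrec. Previously the backedge-taken count had to
survive both a zero- and a sign-extension round trip (which ruled out
255 in i8); now the count is treated as unsigned, and only the final
value is checked rather than each intermediate expression.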

lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/sext-iv-0.ll [new file with mode: 0644]
test/Analysis/ScalarEvolution/sext-iv-1.ll [new file with mode: 0644]

index b81df12b4c667e1e6041a217db310279d8ffb424..257186f4599b6a2b642df574f087f567b0b47fa9 100644 (file)
@@ -718,7 +718,7 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op,
       SCEVHandle BECount = getBackedgeTakenCount(AR->getLoop());
       if (!isa<SCEVCouldNotCompute>(BECount)) {
         // Manually compute the final value for AR, checking for
-        // overflow at each step.
+        // overflow.
         SCEVHandle Start = AR->getStart();
         SCEVHandle Step = AR->getStepRecurrence(*this);
 
@@ -730,41 +730,34 @@ SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op,
             getTruncateOrZeroExtend(CastedBECount, BECount->getType())) {
           const Type *WideTy =
             IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
+          // Check whether Start+Step*BECount has no unsigned overflow.
           SCEVHandle ZMul =
             getMulExpr(CastedBECount,
                        getTruncateOrZeroExtend(Step, Start->getType()));
-          // Check whether Start+Step*BECount has no unsigned overflow.
-          if (getZeroExtendExpr(ZMul, WideTy) ==
-              getMulExpr(getZeroExtendExpr(CastedBECount, WideTy),
-                         getZeroExtendExpr(Step, WideTy))) {
-            SCEVHandle Add = getAddExpr(Start, ZMul);
-            if (getZeroExtendExpr(Add, WideTy) ==
-                getAddExpr(getZeroExtendExpr(Start, WideTy),
-                           getZeroExtendExpr(ZMul, WideTy)))
-              // Return the expression with the addrec on the outside.
-              return getAddRecExpr(getZeroExtendExpr(Start, Ty),
-                                   getZeroExtendExpr(Step, Ty),
-                                   AR->getLoop());
-          }
+          SCEVHandle Add = getAddExpr(Start, ZMul);
+          if (getZeroExtendExpr(Add, WideTy) ==
+              getAddExpr(getZeroExtendExpr(Start, WideTy),
+                         getMulExpr(getZeroExtendExpr(CastedBECount, WideTy),
+                                    getZeroExtendExpr(Step, WideTy))))
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
+                                 getZeroExtendExpr(Step, Ty),
+                                 AR->getLoop());
 
           // Similar to above, only this time treat the step value as signed.
           // This covers loops that count down.
           SCEVHandle SMul =
             getMulExpr(CastedBECount,
                        getTruncateOrSignExtend(Step, Start->getType()));
-          // Check whether Start+Step*BECount has no unsigned overflow.
-          if (getSignExtendExpr(SMul, WideTy) ==
-              getMulExpr(getZeroExtendExpr(CastedBECount, WideTy),
-                         getSignExtendExpr(Step, WideTy))) {
-            SCEVHandle Add = getAddExpr(Start, SMul);
-            if (getZeroExtendExpr(Add, WideTy) ==
-                getAddExpr(getZeroExtendExpr(Start, WideTy),
-                           getSignExtendExpr(SMul, WideTy)))
-              // Return the expression with the addrec on the outside.
-              return getAddRecExpr(getZeroExtendExpr(Start, Ty),
-                                   getSignExtendExpr(Step, Ty),
-                                   AR->getLoop());
-          }
+          Add = getAddExpr(Start, SMul);
+          if (getZeroExtendExpr(Add, WideTy) ==
+              getAddExpr(getZeroExtendExpr(Start, WideTy),
+                         getMulExpr(getZeroExtendExpr(CastedBECount, WideTy),
+                                    getSignExtendExpr(Step, WideTy))))
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
+                                 getSignExtendExpr(Step, Ty),
+                                 AR->getLoop());
         }
       }
     }
@@ -807,37 +800,31 @@ SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op,
       SCEVHandle BECount = getBackedgeTakenCount(AR->getLoop());
       if (!isa<SCEVCouldNotCompute>(BECount)) {
         // Manually compute the final value for AR, checking for
-        // overflow at each step.
+        // overflow.
         SCEVHandle Start = AR->getStart();
         SCEVHandle Step = AR->getStepRecurrence(*this);
 
         // Check whether the backedge-taken count can be losslessly casted to
-        // the addrec's type. The count needs to be the same whether sign
-        // extended or zero extended.
+        // the addrec's type. The count is always unsigned.
         SCEVHandle CastedBECount =
           getTruncateOrZeroExtend(BECount, Start->getType());
         if (BECount ==
-            getTruncateOrZeroExtend(CastedBECount, BECount->getType()) &&
-            BECount ==
-            getTruncateOrSignExtend(CastedBECount, BECount->getType())) {
+            getTruncateOrZeroExtend(CastedBECount, BECount->getType())) {
           const Type *WideTy =
             IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
+          // Check whether Start+Step*BECount has no signed overflow.
           SCEVHandle SMul =
             getMulExpr(CastedBECount,
                        getTruncateOrSignExtend(Step, Start->getType()));
-          // Check whether Start+Step*BECount has no signed overflow.
-          if (getSignExtendExpr(SMul, WideTy) ==
-              getMulExpr(getSignExtendExpr(CastedBECount, WideTy),
-                         getSignExtendExpr(Step, WideTy))) {
-            SCEVHandle Add = getAddExpr(Start, SMul);
-            if (getSignExtendExpr(Add, WideTy) ==
-                getAddExpr(getSignExtendExpr(Start, WideTy),
-                           getSignExtendExpr(SMul, WideTy)))
-              // Return the expression with the addrec on the outside.
-              return getAddRecExpr(getSignExtendExpr(Start, Ty),
-                                   getSignExtendExpr(Step, Ty),
-                                   AR->getLoop());
-          }
+          SCEVHandle Add = getAddExpr(Start, SMul);
+          if (getSignExtendExpr(Add, WideTy) ==
+              getAddExpr(getSignExtendExpr(Start, WideTy),
+                         getMulExpr(getZeroExtendExpr(CastedBECount, WideTy),
+                                    getSignExtendExpr(Step, WideTy))))
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr(getSignExtendExpr(Start, Ty),
+                                 getSignExtendExpr(Step, Ty),
+                                 AR->getLoop());
         }
       }
     }
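
The comparison above can be illustrated outside of SCEV. Below is a
standalone C++ sketch (plain integer types and an invented helper name,
not the ScalarEvolution API) of the sign-extend path's final-value
check, with i8 as the addrec's type and i16 standing in for WideTy:

    #include <cstdint>
    #include <cstdio>

    // Is sext(Start + Step*BECount), evaluated in i8 with wrapping,
    // equal to sext(Start) + zext(BECount)*sext(Step) evaluated in
    // i16? (Casts to int8_t rely on two's-complement wrap, which is
    // guaranteed since C++20 and universal in practice before that.)
    static bool foldIsSafe(int8_t Start, int8_t Step, uint64_t BECount) {
      // The count must be losslessly truncatable to the addrec's
      // type; it is always treated as unsigned.
      if (BECount != (uint8_t)BECount)
        return false;
      uint8_t BC = (uint8_t)BECount;
      int8_t SMul = (int8_t)(Step * BC);    // i8 multiply, may wrap
      int8_t Add  = (int8_t)(Start + SMul); // i8 add, may wrap
      int16_t Narrow = Add;                 // sext to the wide type
      int16_t Wide = (int16_t)Start + (int16_t)BC * (int16_t)Step;
      return Narrow == Wide;                // equal => no overflow
    }

    int main() {
      printf("%d\n", foldIsSafe(-128, 1, 255)); // 1: sext-iv-0.ll case
      printf("%d\n", foldIsSafe(-128, 1, 256)); // 0: count doesn't fit i8
      printf("%d\n", foldIsSafe(-128, -1, 1));  // 0: wraps below i8 min
    }

Note that the intermediate multiply (255*1 wraps to -1 in i8) and add
(-128 + -1 wraps to 127) are allowed to wrap individually; only the
combined result has to agree with the wide recomputation.
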
diff --git a/test/Analysis/ScalarEvolution/sext-iv-0.ll b/test/Analysis/ScalarEvolution/sext-iv-0.ll
new file mode 100644 (file)
index 0000000..4b2fcea
--- /dev/null
@@ -0,0 +1,31 @@
+; RUN: llvm-as < %s | opt -disable-output -scalar-evolution -analyze \
+; RUN:  | grep { -->  \{-128,+,1\}<bb1>                Exits: 127} | count 5
+
+; Convert (sext {-128,+,1}) to {sext(-128),+,sext(1)}, since the
+; trip count is within range where this is safe.
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define void @foo(double* nocapture %x) nounwind {
+bb1.thread:
+       br label %bb1
+
+bb1:           ; preds = %bb1, %bb1.thread
+       %i.0.reg2mem.0 = phi i64 [ -128, %bb1.thread ], [ %8, %bb1 ]            ; <i64> [#uses=3]
+       %0 = trunc i64 %i.0.reg2mem.0 to i8             ; <i8> [#uses=1]
+       %1 = trunc i64 %i.0.reg2mem.0 to i9             ; <i9> [#uses=1]
+       %2 = sext i9 %1 to i64          ; <i64> [#uses=1]
+       %3 = getelementptr double* %x, i64 %2           ; <double*> [#uses=1]
+       %4 = load double* %3, align 8           ; <double> [#uses=1]
+       %5 = mul double %4, 3.900000e+00                ; <double> [#uses=1]
+       %6 = sext i8 %0 to i64          ; <i64> [#uses=1]
+       %7 = getelementptr double* %x, i64 %6           ; <double*> [#uses=1]
+       store double %5, double* %7, align 8
+       %8 = add i64 %i.0.reg2mem.0, 1          ; <i64> [#uses=2]
+       %9 = icmp sgt i64 %8, 127               ; <i1> [#uses=1]
+       br i1 %9, label %return, label %bb1
+
+return:                ; preds = %bb1
+       ret void
+}
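
With the fold in place, the five lines the RUN line greps for are
presumably the i64 induction variable itself, the two truncs (%0 and
%1), and the two folded sexts (%2 and %6): all of them now print as
{-128,+,1}<bb1> with exit value 127.
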
diff --git a/test/Analysis/ScalarEvolution/sext-iv-1.ll b/test/Analysis/ScalarEvolution/sext-iv-1.ll
new file mode 100644 (file)
index 0000000..a9175c3
--- /dev/null
@@ -0,0 +1,100 @@
+; RUN: llvm-as < %s | opt -disable-output -scalar-evolution -analyze \
+; RUN:  | grep { -->  (sext i. \{.\*,+,.\*\}<bb1> to i64)} | count 5
+
+; Don't convert (sext {...,+,...}) to {sext(...),+,sext(...)} in cases
+; where the trip count is not within range.
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define void @foo0(double* nocapture %x) nounwind {
+bb1.thread:
+       br label %bb1
+
+bb1:           ; preds = %bb1, %bb1.thread
+       %i.0.reg2mem.0 = phi i64 [ -128, %bb1.thread ], [ %8, %bb1 ]            ; <i64> [#uses=3]
+       %0 = trunc i64 %i.0.reg2mem.0 to i7             ; <i7> [#uses=1]
+       %1 = trunc i64 %i.0.reg2mem.0 to i9             ; <i9> [#uses=1]
+       %2 = sext i9 %1 to i64          ; <i64> [#uses=1]
+       %3 = getelementptr double* %x, i64 %2           ; <double*> [#uses=1]
+       %4 = load double* %3, align 8           ; <double> [#uses=1]
+       %5 = mul double %4, 3.900000e+00                ; <double> [#uses=1]
+       %6 = sext i7 %0 to i64          ; <i64> [#uses=1]
+       %7 = getelementptr double* %x, i64 %6           ; <double*> [#uses=1]
+       store double %5, double* %7, align 8
+       %8 = add i64 %i.0.reg2mem.0, 1          ; <i64> [#uses=2]
+       %9 = icmp sgt i64 %8, 127               ; <i1> [#uses=1]
+       br i1 %9, label %return, label %bb1
+
+return:                ; preds = %bb1
+       ret void
+}
+
+define void @foo1(double* nocapture %x) nounwind {
+bb1.thread:
+       br label %bb1
+
+bb1:           ; preds = %bb1, %bb1.thread
+       %i.0.reg2mem.0 = phi i64 [ -128, %bb1.thread ], [ %8, %bb1 ]            ; <i64> [#uses=3]
+       %0 = trunc i64 %i.0.reg2mem.0 to i8             ; <i8> [#uses=1]
+       %1 = trunc i64 %i.0.reg2mem.0 to i9             ; <i9> [#uses=1]
+       %2 = sext i9 %1 to i64          ; <i64> [#uses=1]
+       %3 = getelementptr double* %x, i64 %2           ; <double*> [#uses=1]
+       %4 = load double* %3, align 8           ; <double> [#uses=1]
+       %5 = mul double %4, 3.900000e+00                ; <double> [#uses=1]
+       %6 = sext i8 %0 to i64          ; <i64> [#uses=1]
+       %7 = getelementptr double* %x, i64 %6           ; <double*> [#uses=1]
+       store double %5, double* %7, align 8
+       %8 = add i64 %i.0.reg2mem.0, 1          ; <i64> [#uses=2]
+       %9 = icmp sgt i64 %8, 128               ; <i1> [#uses=1]
+       br i1 %9, label %return, label %bb1
+
+return:                ; preds = %bb1
+       ret void
+}
+
+define void @foo2(double* nocapture %x) nounwind {
+bb1.thread:
+       br label %bb1
+
+bb1:           ; preds = %bb1, %bb1.thread
+       %i.0.reg2mem.0 = phi i64 [ -129, %bb1.thread ], [ %8, %bb1 ]            ; <i64> [#uses=3]
+       %0 = trunc i64 %i.0.reg2mem.0 to i8             ; <i8> [#uses=1]
+       %1 = trunc i64 %i.0.reg2mem.0 to i9             ; <i9> [#uses=1]
+       %2 = sext i9 %1 to i64          ; <i64> [#uses=1]
+       %3 = getelementptr double* %x, i64 %2           ; <double*> [#uses=1]
+       %4 = load double* %3, align 8           ; <double> [#uses=1]
+       %5 = mul double %4, 3.900000e+00                ; <double> [#uses=1]
+       %6 = sext i8 %0 to i64          ; <i64> [#uses=1]
+       %7 = getelementptr double* %x, i64 %6           ; <double*> [#uses=1]
+       store double %5, double* %7, align 8
+       %8 = add i64 %i.0.reg2mem.0, 1          ; <i64> [#uses=2]
+       %9 = icmp sgt i64 %8, 127               ; <i1> [#uses=1]
+       br i1 %9, label %return, label %bb1
+
+return:                ; preds = %bb1
+       ret void
+}
+
+define void @foo3(double* nocapture %x) nounwind {
+bb1.thread:
+       br label %bb1
+
+bb1:           ; preds = %bb1, %bb1.thread
+       %i.0.reg2mem.0 = phi i64 [ -128, %bb1.thread ], [ %8, %bb1 ]            ; <i64> [#uses=3]
+       %0 = trunc i64 %i.0.reg2mem.0 to i8             ; <i8> [#uses=1]
+       %1 = trunc i64 %i.0.reg2mem.0 to i9             ; <i9> [#uses=1]
+       %2 = sext i9 %1 to i64          ; <i64> [#uses=1]
+       %3 = getelementptr double* %x, i64 %2           ; <double*> [#uses=1]
+       %4 = load double* %3, align 8           ; <double> [#uses=1]
+       %5 = mul double %4, 3.900000e+00                ; <double> [#uses=1]
+       %6 = sext i8 %0 to i64          ; <i64> [#uses=1]
+       %7 = getelementptr double* %x, i64 %6           ; <double*> [#uses=1]
+       store double %5, double* %7, align 8
+       %8 = add i64 %i.0.reg2mem.0, -1         ; <i64> [#uses=2]
+       %9 = icmp sgt i64 %8, 127               ; <i1> [#uses=1]
+       br i1 %9, label %return, label %bb1
+
+return:                ; preds = %bb1
+       ret void
+}
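
For reference, here is why each of these loops resists the fold
(inferred from the tests; the commit itself only asserts the grep
counts):

  - foo0: truncating to i7 turns the start value -128 into 0, and
    {0,+,1}<i7> wraps during the 256 iterations.
  - foo1: the exit test is sgt 128, so the final value would be
    -128 + 256 = 128, which is not representable in i8 (and the
    count 256 does not fit in i8 either).
  - foo2: starting at -129, the i8 start truncates to 127 and the
    recurrence wraps immediately.
  - foo3: the step is -1 while the exit test is sgt 127, which is
    never true counting down, so there is no computable
    backedge-taken count and no final-value check is possible.

The i9 recurrences in foo0, foo1, and foo2 still fold, which is why
the grep expects exactly five unfolded sexts across the four loops.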