Teach SCEV to handle more cases of 'and X, CST', specifically where CST is any number...
authorNick Lewycky <nicholas@mxc.ca>
Mon, 27 Jan 2014 10:04:03 +0000 (10:04 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Mon, 27 Jan 2014 10:04:03 +0000 (10:04 +0000)
Unfortunately, this in turn led to some lower quality SCEVs due to some different paths through expression simplification, so add getUDivExactExpr and use it. This fixes all instances of the problems that I found, but we can make that function smarter as necessary.

Merge test "xor-and.ll" into "and-xor.ll" since I needed to update it anyways. Test 'nsw-offset.ll' analyzes a little deeper, %n now gets a scev in terms of %no instead of a SCEVUnknown.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@200203 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/and-xor.ll
test/Analysis/ScalarEvolution/fold.ll
test/Analysis/ScalarEvolution/nsw-offset.ll
test/Analysis/ScalarEvolution/xor-and.ll [deleted file]

index e53a8bffd4e9b9a7a2da43d30ecbb1cecbd81dc9..80809da81ae47da6ce1b2c0be003c1421a0597e1 100644 (file)
@@ -624,6 +624,7 @@ namespace llvm {
       return getMulExpr(Ops, Flags);
     }
     const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
+    const SCEV *getUDivExactExpr(const SCEV *LHS, const SCEV *RHS);
     const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step,
                               const Loop *L, SCEV::NoWrapFlags Flags);
     const SCEV *getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
index b65d99e4d67af5a322b0a1794c3ed6c175a6f3a0..e8ee46dff8ca9304370b486328f36a4722b7a1ad 100644 (file)
@@ -321,7 +321,7 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
   return S;
 }
 
-const SCEV *ScalarEvolution::getConstant(const APIntVal) {
+const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
   return getConstant(ConstantInt::get(getContext(), Val));
 }
 
@@ -2239,6 +2239,75 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   return S;
 }
 
+static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
+  APInt A = C1->getValue()->getValue().abs();
+  APInt B = C2->getValue()->getValue().abs();
+  uint32_t ABW = A.getBitWidth();
+  uint32_t BBW = B.getBitWidth();
+
+  if (ABW > BBW)
+    B = B.zext(ABW);
+  else if (ABW < BBW)
+    A = A.zext(BBW);
+
+  return APIntOps::GreatestCommonDivisor(A, B);
+}
+
+/// getUDivExactExpr - Get a canonical unsigned division expression, or
+/// something simpler if possible. There is no representation for an exact udiv
+/// in SCEV IR, but we can attempt to remove factors from the LHS and RHS.
+/// We can't do this when it's not exact because the udiv may be clearing bits.
+const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
+                                              const SCEV *RHS) {
+  // TODO: we could try to find factors in all sorts of things, but for now we
+  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
+  // end of this file for inspiration.
+
+  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
+  if (!Mul)
+    return getUDivExpr(LHS, RHS);
+
+  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
+    // If the mulexpr multiplies by a constant, then that constant must be the
+    // first element of the mulexpr.
+    if (const SCEVConstant *LHSCst =
+            dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
+      if (LHSCst == RHSCst) {
+        SmallVector<const SCEV *, 2> Operands;
+        Operands.append(Mul->op_begin() + 1, Mul->op_end());
+        return getMulExpr(Operands);
+      }
+
+      // We can't just assume that LHSCst divides RHSCst cleanly, it could be
+      // that there's a factor provided by one of the other terms. We need to
+      // check.
+      APInt Factor = gcd(LHSCst, RHSCst);
+      if (!Factor.isIntN(1)) {
+        LHSCst = cast<SCEVConstant>(
+            getConstant(LHSCst->getValue()->getValue().udiv(Factor)));
+        RHSCst = cast<SCEVConstant>(
+            getConstant(RHSCst->getValue()->getValue().udiv(Factor)));
+        SmallVector<const SCEV *, 2> Operands;
+        Operands.push_back(LHSCst);
+        Operands.append(Mul->op_begin() + 1, Mul->op_end());
+        LHS = getMulExpr(Operands);
+        RHS = RHSCst;
+        Mul = cast<SCEVMulExpr>(LHS);
+      }
+    }
+  }
+
+  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
+    if (Mul->getOperand(i) == RHS) {
+      SmallVector<const SCEV *, 2> Operands;
+      Operands.append(Mul->op_begin(), Mul->op_begin() + i);
+      Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
+      return getMulExpr(Operands);
+    }
+  }
+
+  return getUDivExpr(LHS, RHS);
+}
 
 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
 /// Simplify the expression as much as possible.
@@ -3689,17 +3758,24 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
       // knew about to reconstruct a low-bits mask value.
       unsigned LZ = A.countLeadingZeros();
+      unsigned TZ = A.countTrailingZeros();
       unsigned BitWidth = A.getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
       ComputeMaskedBits(U->getOperand(0), KnownZero, KnownOne, TD);
 
-      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
-
-      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
-        return
-          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
-                                IntegerType::get(getContext(), BitWidth - LZ)),
-                            U->getType());
+      APInt EffectiveMask =
+          APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
+      if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) {
+        const SCEV *MulCount = getConstant(
+            ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ)));
+        return getMulExpr(
+            getZeroExtendExpr(
+                getTruncateExpr(
+                    getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount),
+                    IntegerType::get(getContext(), BitWidth - LZ - TZ)),
+                U->getType()),
+            MulCount);
+      }
     }
     break;
 
@@ -6692,20 +6768,6 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
   return SE.getCouldNotCompute();
 }
 
-static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
-  APInt A = C1->getValue()->getValue().abs();
-  APInt B = C2->getValue()->getValue().abs();
-  uint32_t ABW = A.getBitWidth();
-  uint32_t BBW = B.getBitWidth();
-
-  if (ABW > BBW)
-    B = B.zext(ABW);
-  else if (ABW < BBW)
-    A = A.zext(BBW);
-
-  return APIntOps::GreatestCommonDivisor(A, B);
-}
-
 static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) {
   APInt A = C1->getValue()->getValue();
   APInt B = C2->getValue()->getValue();
index 404ab91e269d2e71b293956821701a5e035a157d..ad636da4d4d79f193b4a666768511e50938c098f 100644 (file)
@@ -1,11 +1,27 @@
 ; RUN: opt < %s -scalar-evolution -analyze | FileCheck %s
 
+; CHECK-LABEL: @test1
 ; CHECK: -->  (zext
 ; CHECK: -->  (zext
 ; CHECK-NOT: -->  (zext
 
-define i32 @foo(i32 %x) {
+define i32 @test1(i32 %x) {
   %n = and i32 %x, 255
   %y = xor i32 %n, 255
   ret i32 %y
 }
+
+; ScalarEvolution shouldn't try to analyze %z into something like
+;   -->  (zext i4 (-1 + (-1 * (trunc i64 (8 * %x) to i4))) to i64)
+; or
+;   -->  (8 * (zext i1 (trunc i64 ((8 * %x) /u 8) to i1) to i64))
+
+; CHECK-LABEL: @test2
+; CHECK: -->  (8 * (zext i1 (trunc i64 %x to i1) to i64))
+
+define i64 @test2(i64 %x) {
+  %a = shl i64 %x, 3
+  %t = and i64 %a, 8
+  %z = xor i64 %t, 8
+  ret i64 %z
+}
index 57006dd9bb42c469cc6653a29bc0c37e9a105d88..84b657050c535ef8bcb9b9713a617cb6777555f0 100644 (file)
@@ -60,3 +60,20 @@ loop:
 exit:
   ret void
 }
+
+define void @test5(i32 %i) {
+; CHECK-LABEL: @test5
+  %A = and i32 %i, 1
+; CHECK: -->  (zext i1 (trunc i32 %i to i1) to i32)
+  %B = and i32 %i, 2
+; CHECK: -->  (2 * (zext i1 (trunc i32 (%i /u 2) to i1) to i32))
+  %C = and i32 %i, 63
+; CHECK: -->  (zext i6 (trunc i32 %i to i6) to i32)
+  %D = and i32 %i, 126
+; CHECK: -->  (2 * (zext i6 (trunc i32 (%i /u 2) to i6) to i32))
+  %E = and i32 %i, 64
+; CHECK: -->  (64 * (zext i1 (trunc i32 (%i /u 64) to i1) to i32))
+  %F = and i32 %i, -2147483648
+; CHECK: -->  (-2147483648 * (%i /u -2147483648))
+  ret void
+}
index 8969a5ad4ceb21b0b520637f321b0c5c995336b2..88cdcf23d9ed0a64c42dd06eccf1779958eab015 100644 (file)
@@ -73,5 +73,5 @@ return:                                           ; preds = %bb1.return_crit_edg
   ret void
 }
 
-; CHECK: Loop %bb: backedge-taken count is ((-1 + %n) /u 2)
+; CHECK: Loop %bb: backedge-taken count is ((-1 + (2 * (%no /u 2))) /u 2)
 ; CHECK: Loop %bb: max backedge-taken count is 1073741822
diff --git a/test/Analysis/ScalarEvolution/xor-and.ll b/test/Analysis/ScalarEvolution/xor-and.ll
deleted file mode 100644 (file)
index 2616ea9..0000000
+++ /dev/null
@@ -1,13 +0,0 @@
-; RUN: opt < %s -scalar-evolution -analyze | FileCheck %s
-
-; ScalarEvolution shouldn't try to analyze %z into something like
-;   -->  (zext i4 (-1 + (-1 * (trunc i64 (8 * %x) to i4))) to i64)
-
-; CHECK: -->  (zext i4 (-8 + (trunc i64 (8 * %x) to i4)) to i64)
-
-define i64 @foo(i64 %x) {
-  %a = shl i64 %x, 3
-  %t = and i64 %a, 8
-  %z = xor i64 %t, 8
-  ret i64 %z
-}