Make GetMinTrailingZeros a member function of ScalarEvolution,
authorDan Gohman <gohman@apple.com>
Fri, 19 Jun 2009 23:29:04 +0000 (23:29 +0000)
committerDan Gohman <gohman@apple.com>
Fri, 19 Jun 2009 23:29:04 +0000 (23:29 +0000)
so that it can access the TargetData member (when available) and
use ValueTracking.h information to compute information for
SCEVUnknown Values.

Also add GetMinLeadingZeros and GetMinSignBits functions,
with minimal implementations.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@73794 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp

index 8c64366c320e5b8e3831fde97721bfc3ddb05ad0..211c0ca5d1201558626a15cc046b8d10b5b24528 100644 (file)
@@ -571,6 +571,20 @@ namespace llvm {
     /// is deleted.
     void forgetLoopBackedgeTakenCount(const Loop *L);
 
+    /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
+    /// guaranteed to end in (at every loop iteration).  It is, at the same time,
+    /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
+    /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
+    uint32_t GetMinTrailingZeros(const SCEVHandle &S);
+
+    /// GetMinLeadingZeros - Determine the minimum number of zero bits that S is
+    /// guaranteed to begin with (at every loop iteration).
+    uint32_t GetMinLeadingZeros(const SCEVHandle &S);
+
+    /// GetMinSignBits - Determine the minimum number of sign bits that S is
+    /// guaranteed to begin with.
+    uint32_t GetMinSignBits(const SCEVHandle &S);
+
     virtual bool runOnFunction(Function &F);
     virtual void releaseMemory();
     virtual void getAnalysisUsage(AnalysisUsage &AU) const;
index 049f886c5ab3af42a51e8789e79d78a98eaea691..2dab2f367de26ee730606e98fd1e1852fe1fbc8d 100644 (file)
@@ -2294,73 +2294,134 @@ SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) {
 /// guaranteed to end in (at every loop iteration).  It is, at the same time,
 /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
 /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
-static uint32_t GetMinTrailingZeros(SCEVHandle S, const ScalarEvolution &SE) {
+uint32_t
+ScalarEvolution::GetMinTrailingZeros(const SCEVHandle &S) {
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
     return C->getValue()->getValue().countTrailingZeros();
 
   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
-    return std::min(GetMinTrailingZeros(T->getOperand(), SE),
-                    (uint32_t)SE.getTypeSizeInBits(T->getType()));
+    return std::min(GetMinTrailingZeros(T->getOperand()),
+                    (uint32_t)getTypeSizeInBits(T->getType()));
 
   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
-    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE);
-    return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ?
-             SE.getTypeSizeInBits(E->getType()) : OpRes;
+    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
+             getTypeSizeInBits(E->getType()) : OpRes;
   }
 
   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
-    uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE);
-    return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ?
-             SE.getTypeSizeInBits(E->getType()) : OpRes;
+    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
+    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
+             getTypeSizeInBits(E->getType()) : OpRes;
   }
 
   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
     // The result is the min of all operands results.
-    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE);
+    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE));
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
     return MinOpRes;
   }
 
   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
     // The result is the sum of all operands results.
-    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
-    uint32_t BitWidth = SE.getTypeSizeInBits(M->getType());
+    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
+    uint32_t BitWidth = getTypeSizeInBits(M->getType());
     for (unsigned i = 1, e = M->getNumOperands();
          SumOpRes != BitWidth && i != e; ++i)
-      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), SE),
+      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
                           BitWidth);
     return SumOpRes;
   }
 
   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
     // The result is the min of all operands results.
-    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE);
+    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE));
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
     return MinOpRes;
   }
 
   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
     // The result is the min of all operands results.
-    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
+    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE));
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
     return MinOpRes;
   }
 
   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
     // The result is the min of all operands results.
-    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
+    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
-      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE));
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
     return MinOpRes;
   }
 
-  // SCEVUDivExpr, SCEVUnknown
+  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // For a SCEVUnknown, ask ValueTracking.
+    unsigned BitWidth = getTypeSizeInBits(U->getType());
+    APInt Mask = APInt::getAllOnesValue(BitWidth);
+    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
+    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
+    return Zeros.countTrailingOnes();
+  }
+
+  // SCEVUDivExpr
   return 0;
 }
 
+uint32_t
+ScalarEvolution::GetMinLeadingZeros(const SCEVHandle &S) {
+  // TODO: Handle other SCEV expression types here.
+
+  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
+    return C->getValue()->getValue().countLeadingZeros();
+
+  if (const SCEVZeroExtendExpr *C = dyn_cast<SCEVZeroExtendExpr>(S)) {
+    // A zero-extension cast adds zero bits.
+    return GetMinLeadingZeros(C->getOperand()) +
+           (getTypeSizeInBits(C->getType()) -
+            getTypeSizeInBits(C->getOperand()->getType()));
+  }
+
+  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // For a SCEVUnknown, ask ValueTracking.
+    unsigned BitWidth = getTypeSizeInBits(U->getType());
+    APInt Mask = APInt::getAllOnesValue(BitWidth);
+    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
+    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
+    return Zeros.countLeadingOnes();
+  }
+
+  return 1;
+}
+
+uint32_t
+ScalarEvolution::GetMinSignBits(const SCEVHandle &S) {
+  // TODO: Handle other SCEV expression types here.
+
+  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
+    const APInt &A = C->getValue()->getValue();
+    return A.isNegative() ? A.countLeadingOnes() :
+                            A.countLeadingZeros();
+  }
+
+  if (const SCEVSignExtendExpr *C = dyn_cast<SCEVSignExtendExpr>(S)) {
+    // A sign-extension cast adds sign bits.
+    return GetMinSignBits(C->getOperand()) +
+           (getTypeSizeInBits(C->getType()) -
+            getTypeSizeInBits(C->getOperand()->getType()));
+  }
+
+  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // For a SCEVUnknown, ask ValueTracking.
+    return ComputeNumSignBits(U->getValue(), TD);
+  }
+
+  return 1;
+}
+
 /// createSCEV - We know that there is no SCEV for the specified value.
 /// Analyze the expression.
 ///
@@ -2430,7 +2491,7 @@ SCEVHandle ScalarEvolution::createSCEV(Value *V) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
       SCEVHandle LHS = getSCEV(U->getOperand(0));
       const APInt &CIVal = CI->getValue();
-      if (GetMinTrailingZeros(LHS, *this) >=
+      if (GetMinTrailingZeros(LHS) >=
           (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
         return getAddExpr(LHS, getSCEV(U->getOperand(1)));
     }