Minor code simplification.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 0c17f14892bc15999049607b486a1fbeabc23110..04ddc4e1a7cee7e62cc3a7f8196cebe0d2872513 100644 (file)
@@ -66,6 +66,7 @@
 #include "llvm/GlobalVariable.h"
 #include "llvm/Instructions.h"
 #include "llvm/LLVMContext.h"
+#include "llvm/Operator.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/Dominators.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -191,12 +192,13 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
 }
 
 const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
-  return getConstant(ConstantInt::get(Val));
+  return getConstant(Context->getConstantInt(Val));
 }
 
 const SCEV *
 ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
-  return getConstant(ConstantInt::get(cast<IntegerType>(Ty), V, isSigned));
+  return getConstant(
+    Context->getConstantInt(cast<IntegerType>(Ty), V, isSigned));
 }
 
 const Type *SCEVConstant::getType() const { return V->getType(); }
@@ -959,6 +961,22 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
             return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                  getSignExtendExpr(Step, Ty),
                                  L);
+
+          // Similar to above, only this time treat the step value as unsigned.
+          // This covers loops that count up with an unsigned step.
+          const SCEV *UMul =
+            getMulExpr(CastedMaxBECount,
+                       getTruncateOrZeroExtend(Step, Start->getType()));
+          Add = getAddExpr(Start, UMul);
+          OperandExtendedAdd =
+            getAddExpr(getZeroExtendExpr(Start, WideTy),
+                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
+                                  getZeroExtendExpr(Step, WideTy)));
+          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr(getSignExtendExpr(Start, Ty),
+                                 getZeroExtendExpr(Step, Ty),
+                                 L);
         }
 
         // If the backedge is guarded by a comparison with the pre-inc value
@@ -1500,7 +1518,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops) {
     ++Idx;
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
       // We found two constants, fold them together!
-      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
+      ConstantInt *Fold = Context->getConstantInt(LHSC->getValue()->getValue() *
                                            RHSC->getValue()->getValue());
       Ops[0] = getConstant(Fold);
       Ops.erase(Ops.begin()+1);  // Erase the folded element
@@ -1851,7 +1869,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
     assert(Idx < Ops.size());
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
       // We found two constants, fold them together!
-      ConstantInt *Fold = ConstantInt::get(
+      ConstantInt *Fold = Context->getConstantInt(
                               APIntOps::smax(LHSC->getValue()->getValue(),
                                              RHSC->getValue()->getValue()));
       Ops[0] = getConstant(Fold);
@@ -1948,7 +1966,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
     assert(Idx < Ops.size());
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
       // We found two constants, fold them together!
-      ConstantInt *Fold = ConstantInt::get(
+      ConstantInt *Fold = Context->getConstantInt(
                               APIntOps::umax(LHSC->getValue()->getValue(),
                                              RHSC->getValue()->getValue()));
       Ops[0] = getConstant(Fold);
@@ -2115,7 +2133,7 @@ const SCEV *ScalarEvolution::getSCEV(Value *V) {
 /// specified signed integer value and return a SCEV for the constant.
 const SCEV *ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
   const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
-  return getConstant(ConstantInt::get(ITy, Val));
+  return getConstant(Context->getConstantInt(ITy, Val));
 }
 
 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
@@ -2413,7 +2431,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
 /// createNodeForGEP - Expand GEP instructions into add and multiply
 /// operations. This allows them to be analyzed by regular SCEV code.
 ///
-const SCEV *ScalarEvolution::createNodeForGEP(User *GEP) {
+const SCEV *ScalarEvolution::createNodeForGEP(Operator *GEP) {
 
   const Type *IntPtrTy = TD->getIntPtrType();
   Value *Base = GEP->getOperand(0);
@@ -2614,7 +2632,7 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
         APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
                                    EndRange.getUnsignedMax());
         if (Min.isMinValue() && Max.isMaxValue())
-          return ConstantRange(Min.getBitWidth(), /*isFullSet=*/true);
+          return FullSet;
         return ConstantRange(Min, Max+1);
       }
     }
@@ -2626,7 +2644,9 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
     APInt Mask = APInt::getAllOnesValue(BitWidth);
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
-    return ConstantRange(Ones, ~Zeros);
+    if (Ones == ~Zeros + 1)
+      return FullSet;
+    return ConstantRange(Ones, ~Zeros + 1);
   }
 
   return FullSet;
@@ -2762,7 +2782,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
   else
     return getUnknown(V);
 
-  User *U = cast<User>(V);
+  Operator *U = cast<Operator>(V);
   switch (Opcode) {
   case Instruction::Add:
     return getAddExpr(getSCEV(U->getOperand(0)),
@@ -2870,7 +2890,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     // Turn shift left of a constant amount into a multiply.
     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
-      Constant *X = ConstantInt::get(
+      Constant *X = Context->getConstantInt(
         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
       return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
     }
@@ -2880,7 +2900,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     // Turn logical shift right of a constant into a unsigned divide.
     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
-      Constant *X = ConstantInt::get(
+      Constant *X = Context->getConstantInt(
         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
     }
@@ -2920,15 +2940,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       return getSCEV(U->getOperand(0));
     break;
 
-  case Instruction::IntToPtr:
-    if (!TD) break; // Without TD we can't analyze pointers.
-    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
-                                   TD->getIntPtrType());
-
-  case Instruction::PtrToInt:
-    if (!TD) break; // Without TD we can't analyze pointers.
-    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
-                                   U->getType());
+    // It's tempting to handle inttoptr and ptrtoint, however this can
+    // lead to pointer expressions which cannot be expanded to GEPs
+    // (because they may overflow). For now, the only pointer-typed
+    // expressions we handle are GEPs and address literals.
 
   case Instruction::GetElementPtr:
     if (!TD) break; // Without TD we can't analyze pointers.
@@ -3537,8 +3552,8 @@ ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount(
 
   unsigned MaxSteps = MaxBruteForceIterations;
   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
-    ConstantInt *ItCst =
-      ConstantInt::get(cast<IntegerType>(IdxExpr->getType()), IterationNum);
+    ConstantInt *ItCst = Context->getConstantInt(
+                           cast<IntegerType>(IdxExpr->getType()), IterationNum);
     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
 
     // Form the GEP offset.
@@ -4247,7 +4262,7 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
 
   switch (Pred) {
   default:
-    assert(0 && "Unexpected ICmpInst::Predicate value!");
+    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
     break;
   case ICmpInst::ICMP_SGT:
     Pred = ICmpInst::ICMP_SLT;
@@ -4555,23 +4570,32 @@ ScalarEvolution::isNecessaryCondOperands(ICmpInst::Predicate Pred,
                                          const SCEV *FoundLHS,
                                          const SCEV *FoundRHS) {
   switch (Pred) {
-  default: break;
+  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
+      return true;
+    break;
   case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE:
     if (isKnownPredicate(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
         isKnownPredicate(ICmpInst::ICMP_SGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
     if (isKnownPredicate(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
         isKnownPredicate(ICmpInst::ICMP_SLE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
     if (isKnownPredicate(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
         isKnownPredicate(ICmpInst::ICMP_UGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
     if (isKnownPredicate(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
         isKnownPredicate(ICmpInst::ICMP_ULE, RHS, FoundRHS))
       return true;