Add a RegisterClassInfo class that lazily caches information about
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index eef41554f4fbe7b1096388cde5cf8c93b1e60eaa..025718e09febeb32cbcf0fa28122136194c9b555 100644 (file)
@@ -1035,6 +1035,93 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
   return S;
 }
 
+// Get the limit of a recurrence such that incrementing by Step cannot cause
+// signed overflow as long as the value of the recurrence within the loop does
+// not exceed this limit before incrementing.
+static const SCEV *getOverflowLimitForStep(const SCEV *Step,
+                                           ICmpInst::Predicate *Pred,
+                                           ScalarEvolution *SE) {
+  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
+  if (SE->isKnownPositive(Step)) {
+    *Pred = ICmpInst::ICMP_SLT;
+    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
+                           SE->getSignedRange(Step).getSignedMax());
+  }
+  if (SE->isKnownNegative(Step)) {
+    *Pred = ICmpInst::ICMP_SGT;
+    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
+                       SE->getSignedRange(Step).getSignedMin());
+  }
+  return 0;
+}
+
+// The recurrence AR has been shown to have no signed wrap. Typically, if we can
+// prove NSW for AR, then we can just as easily prove NSW for its preincrement
+// or postincrement sibling. This allows normalizing a sign extended AddRec as
+// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a
+// result, the expression "Step + sext(PreIncAR)" is congruent with
+// "sext(PostIncAR)"
+static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR,
+                                            const Type *Ty,
+                                            ScalarEvolution *SE) {
+  const Loop *L = AR->getLoop();
+  const SCEV *Start = AR->getStart();
+  const SCEV *Step = AR->getStepRecurrence(*SE);
+
+  // Check for a simple looking step prior to loop entry.
+  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
+  if (!SA || SA->getNumOperands() != 2 || SA->getOperand(0) != Step)
+    return 0;
+
+  // This is a postinc AR. Check for overflow on the preinc recurrence using the
+  // same three conditions that getSignExtendedExpr checks.
+
+  // 1. NSW flags on the step increment.
+  const SCEV *PreStart = SA->getOperand(1);
+  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
+    SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
+
+  if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
+    return PreStart;
+
+  // 2. Direct overflow check on the step operation's expression.
+  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
+  const Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
+  const SCEV *OperandExtendedStart =
+    SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy),
+                   SE->getSignExtendExpr(Step, WideTy));
+  if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) {
+    // Cache knowledge of PreAR NSW.
+    if (PreAR)
+      const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(SCEV::FlagNSW);
+    // FIXME: this optimization needs a unit test
+    DEBUG(dbgs() << "SCEV: untested prestart overflow check\n");
+    return PreStart;
+  }
+
+  // 3. Loop precondition.
+  ICmpInst::Predicate Pred;
+  const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE);
+
+  if (OverflowLimit &&
+      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
+    return PreStart;
+  }
+  return 0;
+}
+
+// Get the normalized sign-extended expression for this AddRec's Start.
+static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR,
+                                            const Type *Ty,
+                                            ScalarEvolution *SE) {
+  const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE);
+  if (!PreStart)
+    return SE->getSignExtendExpr(AR->getStart(), Ty);
+
+  return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty),
+                        SE->getSignExtendExpr(PreStart, Ty));
+}
+
 const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
                                                const Type *Ty) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
@@ -1097,7 +1184,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
       // If we have special knowledge that this addrec won't overflow,
       // we don't need to do any further analysis.
       if (AR->getNoWrapFlags(SCEV::FlagNSW))
-        return getAddRecExpr(getSignExtendExpr(Start, Ty),
+        return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
                              getSignExtendExpr(Step, Ty),
                              L, SCEV::FlagNSW);
 
@@ -1133,7 +1220,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
             // Cache knowledge of AR NSW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
             // Return the expression with the addrec on the outside.
-            return getAddRecExpr(getSignExtendExpr(Start, Ty),
+            return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
                                  getSignExtendExpr(Step, Ty),
                                  L, AR->getNoWrapFlags());
           }
@@ -1149,7 +1236,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
             // Cache knowledge of AR NSW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
             // Return the expression with the addrec on the outside.
-            return getAddRecExpr(getSignExtendExpr(Start, Ty),
+            return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
                                  getZeroExtendExpr(Step, Ty),
                                  L, AR->getNoWrapFlags());
           }
@@ -1159,34 +1246,18 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
         // the addrec is safe. Also, if the entry is guarded by a comparison
         // with the start value and the backedge is guarded by a comparison
         // with the post-inc value, the addrec is safe.
-        if (isKnownPositive(Step)) {
-          const SCEV *N = getConstant(APInt::getSignedMinValue(BitWidth) -
-                                      getSignedRange(Step).getSignedMax());
-          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT, AR, N) ||
-              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, Start, N) &&
-               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SLT,
-                                           AR->getPostIncExpr(*this), N))) {
-            // Cache knowledge of AR NSW, which is propagated to this AddRec.
-            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
-            // Return the expression with the addrec on the outside.
-            return getAddRecExpr(getSignExtendExpr(Start, Ty),
-                                 getSignExtendExpr(Step, Ty),
-                                 L, AR->getNoWrapFlags());
-          }
-        } else if (isKnownNegative(Step)) {
-          const SCEV *N = getConstant(APInt::getSignedMaxValue(BitWidth) -
-                                      getSignedRange(Step).getSignedMin());
-          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT, AR, N) ||
-              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGT, Start, N) &&
-               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_SGT,
-                                           AR->getPostIncExpr(*this), N))) {
-            // Cache knowledge of AR NSW, which is propagated to this AddRec.
-            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
-            // Return the expression with the addrec on the outside.
-            return getAddRecExpr(getSignExtendExpr(Start, Ty),
-                                 getSignExtendExpr(Step, Ty),
-                                 L, AR->getNoWrapFlags());
-          }
+        ICmpInst::Predicate Pred;
+        const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this);
+        if (OverflowLimit &&
+            (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
+             (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
+              isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
+                                          OverflowLimit)))) {
+          // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
+          const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+          return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
+                               getSignExtendExpr(Step, Ty),
+                               L, AR->getNoWrapFlags());
         }
       }
     }