Narrow right shifts need to encode their immediates differently from a normal
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 922751fe3d13fd5ee859b81613584a3e48654eea..62244ccb3a03a1222e5014c25a5b0ae0ba823b82 100644 (file)
@@ -819,6 +819,36 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
 
+  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
+  // eliminate all the truncates.
+  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
+    SmallVector<const SCEV *, 4> Operands;
+    bool hasTrunc = false;
+    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
+      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
+      hasTrunc = isa<SCEVTruncateExpr>(S);
+      Operands.push_back(S);
+    }
+    if (!hasTrunc)
+      return getAddExpr(Operands, false, false);
+    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
+  }
+
+  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
+  // eliminate all the truncates.
+  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
+    SmallVector<const SCEV *, 4> Operands;
+    bool hasTrunc = false;
+    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
+      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
+      hasTrunc = isa<SCEVTruncateExpr>(S);
+      Operands.push_back(S);
+    }
+    if (!hasTrunc)
+      return getMulExpr(Operands, false, false);
+    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
+  }
+
   // If the input value is a chrec scev, truncate the chrec's operands.
   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
     SmallVector<const SCEV *, 4> Operands;
@@ -870,6 +900,19 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
 
+  // zext(trunc(x)) --> zext(x) or x or trunc(x)
+  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+    // It's possible the bits taken off by the truncate were all zero bits. If
+    // so, we should be able to simplify this further.
+    const SCEV *X = ST->getOperand();
+    ConstantRange CR = getUnsignedRange(X);
+    unsigned TruncBits = getTypeSizeInBits(ST->getType());
+    unsigned NewBits = getTypeSizeInBits(Ty);
+    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
+            CR.zextOrTrunc(NewBits)))
+      return getTruncateOrZeroExtend(X, Ty);
+  }
+
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can zero extend all of the
   // operands (often constants).  This allows analysis of something like
@@ -994,6 +1037,10 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
     return getSignExtendExpr(SS->getOperand(), Ty);
 
+  // sext(zext(x)) --> zext(x)
+  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
+    return getZeroExtendExpr(SZ->getOperand(), Ty);
+
   // Before doing any expensive analysis, check to see if we've already
   // computed a SCEV for this Op and Ty.
   FoldingSetNodeID ID;
@@ -1003,6 +1050,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
 
+  // If the input value is provably positive, build a zext instead.
+  if (isKnownNonNegative(Op))
+    return getZeroExtendExpr(Op, Ty);
+
+  // sext(trunc(x)) --> sext(x) or x or trunc(x)
+  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
+    // It's possible the bits taken off by the truncate were all sign bits. If
+    // so, we should be able to simplify this further.
+    const SCEV *X = ST->getOperand();
+    ConstantRange CR = getSignedRange(X);
+    unsigned TruncBits = getTypeSizeInBits(ST->getType());
+    unsigned NewBits = getTypeSizeInBits(Ty);
+    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
+            CR.sextOrTrunc(NewBits)))
+      return getTruncateOrSignExtend(X, Ty);
+  }
+
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can sign extend all of the
   // operands (often constants).  This allows analysis of something like
@@ -1560,14 +1624,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       AddRecOps[0] = getAddExpr(LIOps);
 
       // Build the new addrec. Propagate the NUW and NSW flags if both the
-      // outer add and the inner addrec are guaranteed to have no overflow or if
-      // there is no outer part.
-      if (Ops.size() != 1) {
-        HasNUW &= AddRec->hasNoUnsignedWrap();
-        HasNSW &= AddRec->hasNoSignedWrap();
-      }
-      
-      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, HasNUW, HasNSW);
+      // outer add and the inner addrec are guaranteed to have no overflow.
+      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop,
+                                         HasNUW && AddRec->hasNoUnsignedWrap(),
+                                         HasNSW && AddRec->hasNoSignedWrap());
 
       // If all of the other operands were loop invariant, we are done.
       if (Ops.size() == 1) return NewRec;
@@ -2450,8 +2510,9 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
   return getMinusSCEV(AllOnes, V);
 }
 
-/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
-///
+/// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1,
+/// and thus the HasNUW and HasNSW bits apply to the resultant add, not
+/// whether the sub would have overflowed.
 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                           bool HasNUW, bool HasNSW) {
   // Fast path: X - X --> 0.
@@ -2722,11 +2783,23 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
                   HasNUW = true;
                 if (OBO->hasNoSignedWrap())
                   HasNSW = true;
-              } else if (isa<GEPOperator>(BEValueV)) {
-                // If the increment is a GEP, then we know it won't perform an
-                // unsigned overflow, because the address space cannot be
+              } else if (const GEPOperator *GEP = 
+                            dyn_cast<GEPOperator>(BEValueV)) {
+                // If the increment is a GEP, then we know it won't perform a
+                // signed overflow, because the address space cannot be
                 // wrapped around.
-                HasNUW = true;
+                //
+                // NOTE: This isn't strictly true, because you could have an
+                // object straddling the 2G address boundary in a 32-bit address
+                // space (for example).  We really want to model this as a "has
+                // no signed/unsigned wrap" where the base pointer is treated as
+                // unsigned and the increment is known to not have signed
+                // wrapping.
+                //
+                // This is a highly theoretical concern though, and this is good
+                // enough for all cases we know of at this point. :)
+                //                
+                HasNSW |= GEP->isInBounds();
               }
 
               const SCEV *StartVal = getSCEV(StartValueV);
@@ -2797,6 +2870,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
   // Add expression, because the Instruction may be guarded by control flow
   // and the no-overflow bits may not be valid for the expression in any
   // context.
+  bool isInBounds = GEP->isInBounds();
 
   const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
   Value *Base = GEP->getOperand(0);
@@ -2825,7 +2899,8 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
       IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
 
       // Multiply the index by the element size to compute the element offset.
-      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize);
+      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize, /*NUW*/ false,
+                                           /*NSW*/ isInBounds);
 
       // Add the element offset to the running total offset.
       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
@@ -2836,7 +2911,8 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
   const SCEV *BaseS = getSCEV(Base);
 
   // Add the total offset from all the GEP indices to the base.
-  return getAddExpr(BaseS, TotalOffset);
+  return getAddExpr(BaseS, TotalOffset, /*NUW*/ false,
+                    /*NSW*/ isInBounds);
 }
 
 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
@@ -3066,6 +3142,7 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
 ///
 ConstantRange
 ScalarEvolution::getSignedRange(const SCEV *S) {
+  // See if we've computed this range already.
   DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
   if (I != SignedRanges.end())
     return I->second;
@@ -4024,7 +4101,7 @@ static const SCEV *getMinusSCEVForExitTest(const SCEV *LHS, const SCEV *RHS,
     cast<SCEVConstant>(RHSA->getOperand(1))->getValue();
   
   // If the strides are equal, then this is just a (complex) loop invariant
-  // comparison of a/b.
+  // comparison of a and b.
   if (LHSStride == RHSStride)
     return SE.getMinusSCEV(LHSA->getStart(), RHSA->getStart());