add integer overflow check for the fp induction variable
[oota-llvm.git] / lib / Transforms / Scalar / IndVarSimplify.cpp
index c2035f4b85e3d913d4504e57f25da00dbef7c93c..d32082e5dcd283b1c1c56649d7c81114c328d7dc 100644 (file)
@@ -653,10 +653,9 @@ void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PN) {
   // Check incoming value.
   ConstantFP *InitValueVal =
     dyn_cast<ConstantFP>(PN->getIncomingValue(IncomingEdge));
-  if (!InitValueVal) return;
-  
+
   int64_t InitValue;
-  if (!ConvertToSInt(InitValueVal->getValueAPF(), InitValue))
+  if (!InitValueVal || !ConvertToSInt(InitValueVal->getValueAPF(), InitValue))
     return;
 
   // Check IV increment. Reject this PN if increment operation is not
@@ -668,9 +667,9 @@ void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PN) {
   // If this is not an add of the PHI with a constantfp, or if the constant fp
   // is not an integer, bail out.
   ConstantFP *IncValueVal = dyn_cast<ConstantFP>(Incr->getOperand(1));
-  int64_t IntValue;
+  int64_t IncValue;
   if (IncValueVal == 0 || Incr->getOperand(0) != PN ||
-      !ConvertToSInt(IncValueVal->getValueAPF(), IntValue))
+      !ConvertToSInt(IncValueVal->getValueAPF(), IncValue))
     return;
 
   // Check Incr uses. One user is PN and the other user is an exit condition
@@ -692,6 +691,10 @@ void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PN) {
   
   BranchInst *TheBr = cast<BranchInst>(Compare->use_back());
 
+  // FIXME: Need to verify that the branch actually controls the iteration count
+  // of the loop.  If not, the new IV can overflow and noone will notice.
+  
+  
   // If it isn't a comparison with an integer-as-fp (the exit value), we can't
   // transform it.
   ConstantFP *ExitValueVal = dyn_cast<ConstantFP>(Compare->getOperand(1));
@@ -700,22 +703,14 @@ void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PN) {
       !ConvertToSInt(ExitValueVal->getValueAPF(), ExitValue))
     return;
   
-  // We convert the floating point induction variable to a signed i32 value if
-  // we can.  This is only safe if the comparison will not overflow in a way
-  // that won't be trapped by the integer equivalent operations.  Check for this
-  // now.
-  // TODO: We could use i64 if it is native and the range requires it.
-
-  
-  
-  const IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext());
-
   // Find new predicate for integer comparison.
   CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
   switch (Compare->getPredicate()) {
   default: return;  // Unknown comparison.
   case CmpInst::FCMP_OEQ:
   case CmpInst::FCMP_UEQ: NewPred = CmpInst::ICMP_EQ; break;
+  case CmpInst::FCMP_ONE:
+  case CmpInst::FCMP_UNE: NewPred = CmpInst::ICMP_NE; break;
   case CmpInst::FCMP_OGT:
   case CmpInst::FCMP_UGT: NewPred = CmpInst::ICMP_SGT; break;
   case CmpInst::FCMP_OGE:
@@ -725,6 +720,78 @@ void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PN) {
   case CmpInst::FCMP_OLE:
   case CmpInst::FCMP_ULE: NewPred = CmpInst::ICMP_SLE; break;
   }
+  
+  // We convert the floating point induction variable to a signed i32 value if
+  // we can.  This is only safe if the comparison will not overflow in a way
+  // that won't be trapped by the integer equivalent operations.  Check for this
+  // now.
+  // TODO: We could use i64 if it is native and the range requires it.
+  
+  // The start/stride/exit values must all fit in signed i32.
+  if (!isInt<32>(InitValue) || !isInt<32>(IncValue) || !isInt<32>(ExitValue))
+    return;
+
+  // If not actually striding (add x, 0.0), avoid touching the code.
+  if (IncValue == 0)
+    return;
+
+  // Positive and negative strides have different safety conditions.
+  if (IncValue > 0) {
+    // If we have a positive stride, we require the init to be less than the
+    // exit value and an equality or less than comparison.
+    if (InitValue >= ExitValue ||
+        NewPred == CmpInst::ICMP_SGT || NewPred == CmpInst::ICMP_SGE)
+      return;
+    
+    uint32_t Range = uint32_t(ExitValue-InitValue);
+    if (NewPred == CmpInst::ICMP_SLE) {
+      // Normalize SLE -> SLT, check for infinite loop.
+      if (++Range == 0) return;  // Range overflows.
+    }
+    
+    unsigned Leftover = Range % uint32_t(IncValue);
+    
+    // If this is an equality comparison, we require that the strided value
+    // exactly land on the exit value, otherwise the IV condition will wrap
+    // around and do things the fp IV wouldn't.
+    if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) &&
+        Leftover != 0)
+      return;
+    
+    // If the stride would wrap around the i32 before exiting, we can't
+    // transform the IV.
+    if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue)
+      return;
+    
+  } else {
+    // If we have a negative stride, we require the init to be greater than the
+    // exit value and an equality or greater than comparison.
+    if (InitValue >= ExitValue ||
+        NewPred == CmpInst::ICMP_SLT || NewPred == CmpInst::ICMP_SLE)
+      return;
+    
+    uint32_t Range = uint32_t(InitValue-ExitValue);
+    if (NewPred == CmpInst::ICMP_SGE) {
+      // Normalize SGE -> SGT, check for infinite loop.
+      if (++Range == 0) return;  // Range overflows.
+    }
+    
+    unsigned Leftover = Range % uint32_t(-IncValue);
+    
+    // If this is an equality comparison, we require that the strided value
+    // exactly land on the exit value, otherwise the IV condition will wrap
+    // around and do things the fp IV wouldn't.
+    if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) &&
+        Leftover != 0)
+      return;
+    
+    // If the stride would wrap around the i32 before exiting, we can't
+    // transform the IV.
+    if (Leftover != 0 && int32_t(ExitValue+IncValue) > ExitValue)
+      return;
+  }
+  
+  const IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext());
 
   // Insert new integer induction variable.
   PHINode *NewPHI = PHINode::Create(Int32Ty, PN->getName()+".int", PN);
@@ -732,7 +799,7 @@ void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PN) {
                       PN->getIncomingBlock(IncomingEdge));
 
   Value *NewAdd =
-    BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IntValue),
+    BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue),
                               Incr->getName()+".int", Incr);
   NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge));