Use clear() to zero an existing APInt.
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
index 8c398462d408556b84b947853bfe64f42e03e903..43597c85f16237ad35b4229c584eacd0d1b3ce3d 100644 (file)
@@ -2,8 +2,8 @@
 //
 //                     The LLVM Compiler Infrastructure
 //
-// This file was developed by the LLVM research group and is distributed under
-// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
 //
@@ -65,8 +65,9 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
     
     // Otherwise, add any offset that our operands provide.
     gep_type_iterator GTI = gep_type_begin(CE);
-    for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i, ++GTI) {
-      ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(i));
+    for (User::const_op_iterator i = CE->op_begin() + 1, e = CE->op_end();
+         i != e; ++i, ++GTI) {
+      ConstantInt *CI = dyn_cast<ConstantInt>(*i);
       if (!CI) return false;  // Index isn't a simple constant?
       if (CI->getZExtValue() == 0) continue;  // Not adding anything.
       
@@ -122,23 +123,144 @@ static Constant *SymbolicallyEvaluateGEP(Constant* const* Ops, unsigned NumOps,
                                          const Type *ResultTy,
                                          const TargetData *TD) {
   Constant *Ptr = Ops[0];
-  if (!cast<PointerType>(Ptr->getType())->getElementType()->isSized())
+  if (!TD || !cast<PointerType>(Ptr->getType())->getElementType()->isSized())
     return 0;
   
-  if (TD && Ptr->isNullValue()) {
-    // If this is a constant expr gep that is effectively computing an
-    // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12'
-    bool isFoldableGEP = true;
-    for (unsigned i = 1; i != NumOps; ++i)
-      if (!isa<ConstantInt>(Ops[i])) {
-        isFoldableGEP = false;
-        break;
+  uint64_t BasePtr = 0;
+  if (!Ptr->isNullValue()) {
+    // If this is a inttoptr from a constant int, we can fold this as the base,
+    // otherwise we can't.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr))
+      if (CE->getOpcode() == Instruction::IntToPtr)
+        if (ConstantInt *Base = dyn_cast<ConstantInt>(CE->getOperand(0)))
+          BasePtr = Base->getZExtValue();
+    
+    if (BasePtr == 0)
+      return 0;
+  }
+
+  // If this is a constant expr gep that is effectively computing an
+  // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12'
+  for (unsigned i = 1; i != NumOps; ++i)
+    if (!isa<ConstantInt>(Ops[i]))
+      return false;
+  
+  uint64_t Offset = TD->getIndexedOffset(Ptr->getType(),
+                                         (Value**)Ops+1, NumOps-1);
+  Constant *C = ConstantInt::get(TD->getIntPtrType(), Offset+BasePtr);
+  return ConstantExpr::getIntToPtr(C, ResultTy);
+}
+
+/// FoldBitCast - Constant fold bitcast, symbolically evaluating it with 
+/// targetdata.  Return 0 if unfoldable.
+static Constant *FoldBitCast(Constant *C, const Type *DestTy,
+                             const TargetData &TD) {
+  // If this is a bitcast from constant vector -> vector, fold it.
+  if (ConstantVector *CV = dyn_cast<ConstantVector>(C)) {
+    if (const VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) {
+      // If the element types match, VMCore can fold it.
+      unsigned NumDstElt = DestVTy->getNumElements();
+      unsigned NumSrcElt = CV->getNumOperands();
+      if (NumDstElt == NumSrcElt)
+        return 0;
+      
+      const Type *SrcEltTy = CV->getType()->getElementType();
+      const Type *DstEltTy = DestVTy->getElementType();
+      
+      // Otherwise, we're changing the number of elements in a vector, which 
+      // requires endianness information to do the right thing.  For example,
+      //    bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+      // folds to (little endian):
+      //    <4 x i32> <i32 0, i32 0, i32 1, i32 0>
+      // and to (big endian):
+      //    <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+      
+      // First thing is first.  We only want to think about integer here, so if
+      // we have something in FP form, recast it as integer.
+      if (DstEltTy->isFloatingPoint()) {
+        // Fold to an vector of integers with same size as our FP type.
+        unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits();
+        const Type *DestIVTy = VectorType::get(IntegerType::get(FPWidth),
+                                               NumDstElt);
+        // Recursively handle this integer conversion, if possible.
+        C = FoldBitCast(C, DestIVTy, TD);
+        if (!C) return 0;
+        
+        // Finally, VMCore can handle this now that #elts line up.
+        return ConstantExpr::getBitCast(C, DestTy);
       }
-    if (isFoldableGEP) {
-      uint64_t Offset = TD->getIndexedOffset(Ptr->getType(),
-                                             (Value**)Ops+1, NumOps-1);
-      Constant *C = ConstantInt::get(TD->getIntPtrType(), Offset);
-      return ConstantExpr::getIntToPtr(C, ResultTy);
+      
+      // Okay, we know the destination is integer, if the input is FP, convert
+      // it to integer first.
+      if (SrcEltTy->isFloatingPoint()) {
+        unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+        const Type *SrcIVTy = VectorType::get(IntegerType::get(FPWidth),
+                                              NumSrcElt);
+        // Ask VMCore to do the conversion now that #elts line up.
+        C = ConstantExpr::getBitCast(C, SrcIVTy);
+        CV = dyn_cast<ConstantVector>(C);
+        if (!CV) return 0;  // If VMCore wasn't able to fold it, bail out.
+      }
+      
+      // Now we know that the input and output vectors are both integer vectors
+      // of the same size, and that their #elements is not the same.  Do the
+      // conversion here, which depends on whether the input or output has
+      // more elements.
+      bool isLittleEndian = TD.isLittleEndian();
+      
+      SmallVector<Constant*, 32> Result;
+      if (NumDstElt < NumSrcElt) {
+        // Handle: bitcast (<4 x i32> <i32 0, i32 1, i32 2, i32 3> to <2 x i64>)
+        Constant *Zero = Constant::getNullValue(DstEltTy);
+        unsigned Ratio = NumSrcElt/NumDstElt;
+        unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits();
+        unsigned SrcElt = 0;
+        for (unsigned i = 0; i != NumDstElt; ++i) {
+          // Build each element of the result.
+          Constant *Elt = Zero;
+          unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1);
+          for (unsigned j = 0; j != Ratio; ++j) {
+            Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(SrcElt++));
+            if (!Src) return 0;  // Reject constantexpr elements.
+            
+            // Zero extend the element to the right size.
+            Src = ConstantExpr::getZExt(Src, Elt->getType());
+            
+            // Shift it to the right place, depending on endianness.
+            Src = ConstantExpr::getShl(Src, 
+                                    ConstantInt::get(Src->getType(), ShiftAmt));
+            ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize;
+            
+            // Mix it in.
+            Elt = ConstantExpr::getOr(Elt, Src);
+          }
+          Result.push_back(Elt);
+        }
+      } else {
+        // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+        unsigned Ratio = NumDstElt/NumSrcElt;
+        unsigned DstBitSize = DstEltTy->getPrimitiveSizeInBits();
+        
+        // Loop over each source value, expanding into multiple results.
+        for (unsigned i = 0; i != NumSrcElt; ++i) {
+          Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(i));
+          if (!Src) return 0;  // Reject constantexpr elements.
+
+          unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1);
+          for (unsigned j = 0; j != Ratio; ++j) {
+            // Shift the piece of the value into the right place, depending on
+            // endianness.
+            Constant *Elt = ConstantExpr::getLShr(Src, 
+                                ConstantInt::get(Src->getType(), ShiftAmt));
+            ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize;
+
+            // Truncate and remember this piece.
+            Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy));
+          }
+        }
+      }
+      
+      return ConstantVector::get(&Result[0], Result.size());
     }
   }
   
@@ -176,8 +298,8 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
   // Scan the operand list, checking to see if they are all constants, if so,
   // hand off to ConstantFoldInstOperands.
   SmallVector<Constant*, 8> Ops;
-  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
-    if (Constant *Op = dyn_cast<Constant>(I->getOperand(i)))
+  for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i)
+    if (Constant *Op = dyn_cast<Constant>(*i))
       Ops.push_back(Op);
     else
       return 0;  // All operands not constant!
@@ -190,6 +312,25 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
                                     &Ops[0], Ops.size(), TD);
 }
 
+/// ConstantFoldConstantExpression - Attempt to fold the constant expression
+/// using the specified TargetData.  If successful, the constant result is
+/// result is returned, if not, null is returned.
+Constant *llvm::ConstantFoldConstantExpression(ConstantExpr *CE,
+                                               const TargetData *TD) {
+  assert(TD && "ConstantFoldConstantExpression requires a valid TargetData.");
+
+  SmallVector<Constant*, 8> Ops;
+  for (User::op_iterator i = CE->op_begin(), e = CE->op_end(); i != e; ++i)
+    Ops.push_back(cast<Constant>(*i));
+
+  if (CE->isCompare())
+    return ConstantFoldCompareInstOperands(CE->getPredicate(),
+                                           &Ops[0], Ops.size(), TD);
+  else 
+    return ConstantFoldInstOperands(CE->getOpcode(), CE->getType(),
+                                    &Ops[0], Ops.size(), TD);
+}
+
 /// ConstantFoldInstOperands - Attempt to constant fold an instruction with the
 /// specified opcode and operands.  If successful, the constant result is
 /// returned, if not, null is returned.  Note that this function can fail when
@@ -233,7 +374,7 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, const Type *DestTy,
         return ConstantExpr::getIntegerCast(Input, DestTy, false);
       }
     }
-    // FALL THROUGH.
+    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::IntToPtr:
   case Instruction::Trunc:
   case Instruction::ZExt:
@@ -244,8 +385,12 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, const Type *DestTy,
   case Instruction::SIToFP:
   case Instruction::FPToUI:
   case Instruction::FPToSI:
+      return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::BitCast:
-    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
+    if (TD)
+      if (Constant *C = FoldBitCast(Ops[0], DestTy, *TD))
+        return C;
+    return ConstantExpr::getBitCast(Ops[0], DestTy);
   case Instruction::Select:
     return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
   case Instruction::ExtractElement:
@@ -272,7 +417,7 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
                                                 const TargetData *TD) {
   // fold: icmp (inttoptr x), null         -> icmp x, 0
   // fold: icmp (ptrtoint x), 0            -> icmp x, null
-  // fold: icmp (inttoptr x), (inttoptr y) -> icmp x, y
+  // fold: icmp (inttoptr x), (inttoptr y) -> icmp trunc/zext x, trunc/zext y
   // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y
   //
   // ConstantExpr::getCompare cannot do this, because it doesn't have TD
@@ -300,21 +445,31 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
       }
     }
     
-    if (TD && isa<ConstantExpr>(Ops[1]) &&
-        cast<ConstantExpr>(Ops[1])->getOpcode() == CE0->getOpcode()) {
-      const Type *IntPtrTy = TD->getIntPtrType();
-      // Only do this transformation if the int is intptrty in size, otherwise
-      // there is a truncation or extension that we aren't modeling.
-      if ((CE0->getOpcode() == Instruction::IntToPtr &&
-           CE0->getOperand(0)->getType() == IntPtrTy &&
-           CE0->getOperand(1)->getType() == IntPtrTy) ||
-          (CE0->getOpcode() == Instruction::PtrToInt &&
-           CE0->getType() == IntPtrTy &&
-           CE0->getOperand(0)->getType() == CE0->getOperand(1)->getType())) {
-        Constant *NewOps[] = { 
-          CE0->getOperand(0), cast<ConstantExpr>(Ops[1])->getOperand(0) 
-        };
-        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
+    if (ConstantExpr *CE1 = dyn_cast<ConstantExpr>(Ops[1])) {
+      if (TD && CE0->getOpcode() == CE1->getOpcode()) {
+        const Type *IntPtrTy = TD->getIntPtrType();
+
+        if (CE0->getOpcode() == Instruction::IntToPtr) {
+          // Convert the integer value to the right size to ensure we get the
+          // proper extension or truncation.
+          Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+                                                      IntPtrTy, false);
+          Constant *C1 = ConstantExpr::getIntegerCast(CE1->getOperand(0),
+                                                      IntPtrTy, false);
+          Constant *NewOps[] = { C0, C1 };
+          return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
+        }
+
+        // Only do this transformation if the int is intptrty in size, otherwise
+        // there is a truncation or extension that we aren't modeling.
+        if ((CE0->getOpcode() == Instruction::PtrToInt &&
+             CE0->getType() == IntPtrTy &&
+             CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType())) {
+          Constant *NewOps[] = { 
+            CE0->getOperand(0), CE1->getOperand(0) 
+          };
+          return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
+        }
       }
     }
   }
@@ -388,7 +543,7 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
 /// canConstantFoldCallTo - Return true if its even possible to fold a call to
 /// the specified function.
 bool
-llvm::canConstantFoldCallTo(Function *F) {
+llvm::canConstantFoldCallTo(const Function *F) {
   switch (F->getIntrinsicID()) {
   case Intrinsic::sqrt:
   case Intrinsic::powi:
@@ -449,7 +604,8 @@ llvm::canConstantFoldCallTo(Function *F) {
     if (Len == 3)
       return !strcmp(Str, "sin");
     if (Len == 4)
-      return !strcmp(Str, "sinh") || !strcmp(Str, "sqrt");
+      return !strcmp(Str, "sinh") || !strcmp(Str, "sqrt") ||
+             !strcmp(Str, "sinf");
     if (Len == 5)
       return !strcmp(Str, "sqrtf");
     return false;
@@ -466,16 +622,17 @@ static Constant *ConstantFoldFP(double (*NativeFP)(double), double V,
                                 const Type *Ty) {
   errno = 0;
   V = NativeFP(V);
-  if (errno == 0) {
-    if (Ty==Type::FloatTy)
-      return ConstantFP::get(Ty, APFloat((float)V));
-    else if (Ty==Type::DoubleTy)
-      return ConstantFP::get(Ty, APFloat(V));
-    else
-      assert(0);
+  if (errno != 0) {
+    errno = 0;
+    return 0;
   }
-  errno = 0;
-  return 0;
+  
+  if (Ty == Type::FloatTy)
+    return ConstantFP::get(APFloat((float)V));
+  if (Ty == Type::DoubleTy)
+    return ConstantFP::get(APFloat(V));
+  assert(0 && "Can only constant fold float/double");
+  return 0; // dummy return to suppress warning
 }
 
 static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
@@ -483,16 +640,17 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
                                       const Type *Ty) {
   errno = 0;
   V = NativeFP(V, W);
-  if (errno == 0) {
-    if (Ty==Type::FloatTy)
-      return ConstantFP::get(Ty, APFloat((float)V));
-    else if (Ty==Type::DoubleTy)
-      return ConstantFP::get(Ty, APFloat(V));
-    else
-      assert(0);
+  if (errno != 0) {
+    errno = 0;
+    return 0;
   }
-  errno = 0;
-  return 0;
+  
+  if (Ty == Type::FloatTy)
+    return ConstantFP::get(APFloat((float)V));
+  if (Ty == Type::DoubleTy)
+    return ConstantFP::get(APFloat(V));
+  assert(0 && "Can only constant fold float/double");
+  return 0; // dummy return to suppress warning
 }
 
 /// ConstantFoldCall - Attempt to constant fold a call to the specified function
@@ -533,6 +691,8 @@ llvm::ConstantFoldCall(Function *F,
           return ConstantFoldFP(cos, V, Ty);
         else if (Len == 4 && !strcmp(Str, "cosh"))
           return ConstantFoldFP(cosh, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "cosf"))
+          return ConstantFoldFP(cos, V, Ty);
         break;
       case 'e':
         if (Len == 3 && !strcmp(Str, "exp"))
@@ -554,8 +714,7 @@ llvm::ConstantFoldCall(Function *F,
           if (V >= -0.0)
             return ConstantFoldFP(sqrt, V, Ty);
           else // Undefined
-            return ConstantFP::get(Ty, Ty==Type::FloatTy ? APFloat(0.0f) :
-                                       APFloat(0.0));
+            return Constant::getNullValue(Ty);
         }
         break;
       case 's':
@@ -567,6 +726,8 @@ llvm::ConstantFoldCall(Function *F,
           return ConstantFoldFP(sqrt, V, Ty);
         else if (Len == 5 && !strcmp(Str, "sqrtf") && V >= 0)
           return ConstantFoldFP(sqrt, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "sinf"))
+          return ConstantFoldFP(sin, V, Ty);
         break;
       case 't':
         if (Len == 3 && !strcmp(Str, "tan"))
@@ -608,11 +769,11 @@ llvm::ConstantFoldCall(Function *F,
         }
       } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
         if (!strcmp(Str, "llvm.powi.f32")) {
-          return ConstantFP::get(Ty, APFloat((float)std::pow((float)Op1V,
-                                              (int)Op2C->getZExtValue())));
+          return ConstantFP::get(APFloat((float)std::pow((float)Op1V,
+                                                 (int)Op2C->getZExtValue())));
         } else if (!strcmp(Str, "llvm.powi.f64")) {
-          return ConstantFP::get(Ty, APFloat((double)std::pow((double)Op1V,
-                                              (int)Op2C->getZExtValue())));
+          return ConstantFP::get(APFloat((double)std::pow((double)Op1V,
+                                                 (int)Op2C->getZExtValue())));
         }
       }
     }