Boost the power of phi node constant folding slightly: if all
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
index 0ac512726ac18fec2afbbaa952bd6da5d0d7355a..f2e89a773e7dc945f180a9e4aa90ad22f5214ea4 100644 (file)
@@ -30,6 +30,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/System/FEnv.h"
 #include <cerrno>
 #include <cmath>
 using namespace llvm;
@@ -208,7 +209,7 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
          i != e; ++i, ++GTI) {
       ConstantInt *CI = dyn_cast<ConstantInt>(*i);
       if (!CI) return false;  // Index isn't a simple constant?
-      if (CI->getZExtValue() == 0) continue;  // Not adding anything.
+      if (CI->isZero()) continue;  // Not adding anything.
       
       if (const StructType *ST = dyn_cast<StructType>(*GTI)) {
         // N = N + Offset
@@ -401,7 +402,7 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
   APInt ResultVal = APInt(IntType->getBitWidth(), RawBytes[BytesLoaded-1]);
   for (unsigned i = 1; i != BytesLoaded; ++i) {
     ResultVal <<= 8;
-    ResultVal |= APInt(IntType->getBitWidth(), RawBytes[BytesLoaded-1-i]);
+    ResultVal |= RawBytes[BytesLoaded-1-i];
   }
 
   return ConstantInt::get(IntType->getContext(), ResultVal);
@@ -436,8 +437,10 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C,
     unsigned StrLen = Str.length();
     const Type *Ty = cast<PointerType>(CE->getType())->getElementType();
     unsigned NumBits = Ty->getPrimitiveSizeInBits();
-    // Replace LI with immediate integer store.
-    if ((NumBits >> 3) == StrLen + 1) {
+    // Replace load with immediate integer if the result is an integer or fp
+    // value.
+    if ((NumBits >> 3) == StrLen + 1 && (NumBits & 7) == 0 &&
+        (isa<IntegerType>(Ty) || Ty->isFloatingPointTy())) {
       APInt StrVal(NumBits, 0);
       APInt SingleChar(NumBits, 0);
       if (TD->isLittleEndian()) {
@@ -454,7 +457,11 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C,
         SingleChar = 0;
         StrVal = (StrVal << 8) | SingleChar;
       }
-      return ConstantInt::get(CE->getContext(), StrVal);
+      
+      Constant *Res = ConstantInt::get(CE->getContext(), StrVal);
+      if (Ty->isFloatingPointTy())
+        Res = ConstantExpr::getBitCast(Res, Ty);
+      return Res;
     }
   }
   
@@ -688,20 +695,26 @@ static Constant *SymbolicallyEvaluateGEP(Constant *const *Ops, unsigned NumOps,
 /// instructions like loads and stores, which have no constant expression form.
 ///
 Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
+  // Handle PHI nodes specially here...
   if (PHINode *PN = dyn_cast<PHINode>(I)) {
-    if (PN->getNumIncomingValues() == 0)
-      return UndefValue::get(PN->getType());
-
-    Constant *Result = dyn_cast<Constant>(PN->getIncomingValue(0));
-    if (Result == 0) return 0;
-
-    // Handle PHI nodes specially here...
-    for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
-      if (PN->getIncomingValue(i) != Result && PN->getIncomingValue(i) != PN)
-        return 0;   // Not all the same incoming constants...
+    Constant *CommonValue = 0;
+
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+      Value *Incoming = PN->getIncomingValue(i);
+      // If the incoming value is equal to the phi node itself or is undef then
+      // skip it.
+      if (Incoming == PN || isa<UndefValue>(Incoming))
+        continue;
+      // If the incoming value is not a constant, or is a different constant to
+      // the one we saw previously, then give up.
+      Constant *C = dyn_cast<Constant>(Incoming);
+      if (!C || (CommonValue && C != CommonValue))
+        return 0;
+      CommonValue = C;
+    }
 
-    // If we reach here, all incoming values are the same constant.
-    return Result;
+    // If we reach here, all incoming values are the same constant or undef.
+    return CommonValue ? CommonValue : UndefValue::get(PN->getType());
   }
 
   // Scan the operand list, checking to see if they are all constants, if so,
@@ -772,9 +785,9 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, const Type *DestTy,
   case Instruction::ICmp:
   case Instruction::FCmp: assert(0 && "Invalid for compares");
   case Instruction::Call:
-    if (Function *F = dyn_cast<Function>(Ops[0]))
+    if (Function *F = dyn_cast<Function>(Ops[NumOps - 1]))
       if (canConstantFoldCallTo(F))
-        return ConstantFoldCall(F, Ops+1, NumOps-1);
+        return ConstantFoldCall(F, Ops, NumOps - 1);
     return 0;
   case Instruction::PtrToInt:
     // If the input is a inttoptr, eliminate the pair.  This requires knowing
@@ -994,6 +1007,9 @@ llvm::canConstantFoldCallTo(const Function *F) {
   case Intrinsic::usub_with_overflow:
   case Intrinsic::sadd_with_overflow:
   case Intrinsic::ssub_with_overflow:
+  case Intrinsic::smul_with_overflow:
+  case Intrinsic::convert_from_fp16:
+  case Intrinsic::convert_to_fp16:
     return true;
   default:
     return false;
@@ -1031,10 +1047,10 @@ llvm::canConstantFoldCallTo(const Function *F) {
 
 static Constant *ConstantFoldFP(double (*NativeFP)(double), double V, 
                                 const Type *Ty) {
-  errno = 0;
+  sys::llvm_fenv_clearexcept();
   V = NativeFP(V);
-  if (errno != 0) {
-    errno = 0;
+  if (sys::llvm_fenv_testexcept()) {
+    sys::llvm_fenv_clearexcept();
     return 0;
   }
   
@@ -1048,10 +1064,10 @@ static Constant *ConstantFoldFP(double (*NativeFP)(double), double V,
 
 static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
                                       double V, double W, const Type *Ty) {
-  errno = 0;
+  sys::llvm_fenv_clearexcept();
   V = NativeFP(V, W);
-  if (errno != 0) {
-    errno = 0;
+  if (sys::llvm_fenv_testexcept()) {
+    sys::llvm_fenv_clearexcept();
     return 0;
   }
   
@@ -1074,8 +1090,24 @@ llvm::ConstantFoldCall(Function *F,
   const Type *Ty = F->getReturnType();
   if (NumOperands == 1) {
     if (ConstantFP *Op = dyn_cast<ConstantFP>(Operands[0])) {
+      if (Name == "llvm.convert.to.fp16") {
+        APFloat Val(Op->getValueAPF());
+
+        bool lost = false;
+        Val.convert(APFloat::IEEEhalf, APFloat::rmNearestTiesToEven, &lost);
+
+        return ConstantInt::get(F->getContext(), Val.bitcastToAPInt());
+      }
+
       if (!Ty->isFloatTy() && !Ty->isDoubleTy())
         return 0;
+
+      /// We only fold functions with finite arguments. Folding NaN and inf is
+      /// likely to be aborted with an exception anyway, and some host libms
+      /// have known errors raising exceptions.
+      if (Op->getValueAPF().isNaN() || Op->getValueAPF().isInfinity())
+        return 0;
+
       /// Currently APFloat versions of these functions do not exist, so we use
       /// the host native double versions.  Float versions are not called
       /// directly but for all these it is true (float)(f((double)arg)) ==
@@ -1158,6 +1190,20 @@ llvm::ConstantFoldCall(Function *F,
         return ConstantInt::get(Ty, Op->getValue().countTrailingZeros());
       else if (Name.startswith("llvm.ctlz"))
         return ConstantInt::get(Ty, Op->getValue().countLeadingZeros());
+      else if (Name == "llvm.convert.from.fp16") {
+        APFloat Val(Op->getValue());
+
+        bool lost = false;
+        APFloat::opStatus status =
+          Val.convert(APFloat::IEEEsingle, APFloat::rmNearestTiesToEven, &lost);
+
+        // Conversion is always precise.
+        status = status;
+        assert(status == APFloat::opOK && !lost &&
+               "Precision lost during fp16 constfolding");
+
+        return ConstantFP::get(F->getContext(), Val);
+      }
       return 0;
     }
     
@@ -1209,42 +1255,37 @@ llvm::ConstantFoldCall(Function *F,
       if (ConstantInt *Op2 = dyn_cast<ConstantInt>(Operands[1])) {
         switch (F->getIntrinsicID()) {
         default: break;
-        case Intrinsic::uadd_with_overflow: {
-          Constant *Res = ConstantExpr::getAdd(Op1, Op2);           // result.
-          Constant *Ops[] = {
-            Res, ConstantExpr::getICmp(CmpInst::ICMP_ULT, Res, Op1) // overflow.
-          };
-          return ConstantStruct::get(F->getContext(), Ops, 2, false);
-        }
-        case Intrinsic::usub_with_overflow: {
-          Constant *Res = ConstantExpr::getSub(Op1, Op2);           // result.
+        case Intrinsic::sadd_with_overflow:
+        case Intrinsic::uadd_with_overflow:
+        case Intrinsic::ssub_with_overflow:
+        case Intrinsic::usub_with_overflow:
+        case Intrinsic::smul_with_overflow: {
+          APInt Res;
+          bool Overflow;
+          switch (F->getIntrinsicID()) {
+          default: assert(0 && "Invalid case");
+          case Intrinsic::sadd_with_overflow:
+            Res = Op1->getValue().sadd_ov(Op2->getValue(), Overflow);
+            break;
+          case Intrinsic::uadd_with_overflow:
+            Res = Op1->getValue().uadd_ov(Op2->getValue(), Overflow);
+            break;
+          case Intrinsic::ssub_with_overflow:
+            Res = Op1->getValue().ssub_ov(Op2->getValue(), Overflow);
+            break;
+          case Intrinsic::usub_with_overflow:
+            Res = Op1->getValue().usub_ov(Op2->getValue(), Overflow);
+            break;
+          case Intrinsic::smul_with_overflow:
+            Res = Op1->getValue().smul_ov(Op2->getValue(), Overflow);
+            break;
+          }
           Constant *Ops[] = {
-            Res, ConstantExpr::getICmp(CmpInst::ICMP_UGT, Res, Op1) // overflow.
+            ConstantInt::get(F->getContext(), Res),
+            ConstantInt::get(Type::getInt1Ty(F->getContext()), Overflow)
           };
           return ConstantStruct::get(F->getContext(), Ops, 2, false);
         }
-        case Intrinsic::sadd_with_overflow: {
-          Constant *Res = ConstantExpr::getAdd(Op1, Op2);           // result.
-          Constant *Overflow = ConstantExpr::getSelect(
-              ConstantExpr::getICmp(CmpInst::ICMP_SGT,
-                ConstantInt::get(Op1->getType(), 0), Op1),
-              ConstantExpr::getICmp(CmpInst::ICMP_SGT, Res, Op2), 
-              ConstantExpr::getICmp(CmpInst::ICMP_SLT, Res, Op2)); // overflow.
-
-          Constant *Ops[] = { Res, Overflow };
-          return ConstantStruct::get(F->getContext(), Ops, 2, false);
-        }
-        case Intrinsic::ssub_with_overflow: {
-          Constant *Res = ConstantExpr::getSub(Op1, Op2);           // result.
-          Constant *Overflow = ConstantExpr::getSelect(
-              ConstantExpr::getICmp(CmpInst::ICMP_SGT,
-                ConstantInt::get(Op2->getType(), 0), Op2),
-              ConstantExpr::getICmp(CmpInst::ICMP_SLT, Res, Op1), 
-              ConstantExpr::getICmp(CmpInst::ICMP_SGT, Res, Op1)); // overflow.
-
-          Constant *Ops[] = { Res, Overflow };
-          return ConstantStruct::get(F->getContext(), Ops, 2, false);
-        }
         }
       }