Compare all 4 bytes of the header.
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 8dc2caafe83a4cb4f1b05ec2114d97cbe468231f..b7a1350ff5ad02104ff1af33281541f2f37366aa 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/Function.h"
 #include "llvm/GlobalAlias.h"
 #include "llvm/GlobalVariable.h"
+#include "llvm/Operator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -42,6 +43,10 @@ using namespace llvm;
 /// input vector constant are all simple integer or FP values.
 static Constant *BitCastConstantVector(ConstantVector *CV,
                                        const VectorType *DstTy) {
+
+  if (CV->isAllOnesValue()) return Constant::getAllOnesValue(DstTy);
+  if (CV->isNullValue()) return Constant::getNullValue(DstTy);
+
   // If this cast changes element count then we can't handle it here:
   // doing so requires endianness information.  This should be handled by
   // Analysis/ConstantFolding.cpp
@@ -145,7 +150,7 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) {
     // This allows for other simplifications (although some of them
     // can only be handled by Analysis/ConstantFolding.cpp).
     if (isa<ConstantInt>(V) || isa<ConstantFP>(V))
-      return ConstantExpr::getBitCast(ConstantVector::get(&V, 1), DestPTy);
+      return ConstantExpr::getBitCast(ConstantVector::get(V), DestPTy);
   }
 
   // Finally, implement bitcast folding now.   The code below doesn't handle
@@ -554,7 +559,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       for (unsigned i = 0, e = CV->getType()->getNumElements(); i != e; ++i)
         res.push_back(ConstantExpr::getCast(opc,
                                             CV->getOperand(i), DstEltTy));
-      return ConstantVector::get(DestVecTy, res);
+      return ConstantVector::get(res);
     }
 
   // We actually have to do a cast now. Perform the cast according to the
@@ -689,9 +694,48 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
   if (ConstantInt *CB = dyn_cast<ConstantInt>(Cond))
     return CB->getZExtValue() ? V1 : V2;
 
+  // Check for zero aggregate and ConstantVector of zeros
+  if (Cond->isNullValue()) return V2;
+
+  if (ConstantVector* CondV = dyn_cast<ConstantVector>(Cond)) {
+
+    if (CondV->isAllOnesValue()) return V1;
+
+    const VectorType *VTy = cast<VectorType>(V1->getType());
+    ConstantVector *CP1 = dyn_cast<ConstantVector>(V1);
+    ConstantVector *CP2 = dyn_cast<ConstantVector>(V2);
+
+    if ((CP1 || isa<ConstantAggregateZero>(V1)) &&
+        (CP2 || isa<ConstantAggregateZero>(V2))) {
+
+      // Find the element type of the returned vector
+      const Type *EltTy = VTy->getElementType();
+      unsigned NumElem = VTy->getNumElements();
+      std::vector<Constant*> Res(NumElem);
+
+      bool Valid = true;
+      for (unsigned i = 0; i < NumElem; ++i) {
+        ConstantInt* c = dyn_cast<ConstantInt>(CondV->getOperand(i));
+        if (!c) {
+          Valid = false;
+          break;
+        }
+        Constant *C1 = CP1 ? CP1->getOperand(i) : Constant::getNullValue(EltTy);
+        Constant *C2 = CP2 ? CP2->getOperand(i) : Constant::getNullValue(EltTy);
+        Res[i] = c->getZExtValue() ? C1 : C2;
+      }
+      // If we were able to build the vector, return it
+      if (Valid) return ConstantVector::get(Res);
+    }
+  }
+
+
+  if (isa<UndefValue>(Cond)) {
+    if (isa<UndefValue>(V1)) return V1;
+    return V2;
+  }
   if (isa<UndefValue>(V1)) return V2;
   if (isa<UndefValue>(V2)) return V1;
-  if (isa<UndefValue>(Cond)) return V1;
   if (V1 == V2) return V1;
 
   if (ConstantExpr *TrueVal = dyn_cast<ConstantExpr>(V1)) {
@@ -832,7 +876,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1,
     Result.push_back(InElt);
   }
 
-  return ConstantVector::get(&Result[0], Result.size());
+  return ConstantVector::get(Result);
 }
 
 Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg,
@@ -901,7 +945,7 @@ Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
     }
     
     if (const StructType* ST = dyn_cast<StructType>(AggTy))
-      return ConstantStruct::get(ST->getContext(), Ops, ST->isPacked());
+      return ConstantStruct::get(ST, Ops);
     return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
   }
   
@@ -932,7 +976,7 @@ Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
     }
     
     if (const StructType *ST = dyn_cast<StructType>(AggTy))
-      return ConstantStruct::get(ST->getContext(), Ops, ST->isPacked());
+      return ConstantStruct::get(ST, Ops);
     return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
   }
   
@@ -947,7 +991,7 @@ Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
     }
     
     if (const StructType* ST = dyn_cast<StructType>(Agg->getType()))
-      return ConstantStruct::get(ST->getContext(), Ops, ST->isPacked());
+      return ConstantStruct::get(ST, Ops);
     return ConstantArray::get(cast<ArrayType>(Agg->getType()), Ops);
   }
 
@@ -973,20 +1017,38 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
     case Instruction::Add:
     case Instruction::Sub:
       return UndefValue::get(C1->getType());
-    case Instruction::Mul:
     case Instruction::And:
+      if (isa<UndefValue>(C1) && isa<UndefValue>(C2)) // undef & undef -> undef
+        return C1;
+      return Constant::getNullValue(C1->getType());   // undef & X -> 0
+    case Instruction::Mul: {
+      ConstantInt *CI;
+      // X * undef -> undef   if X is odd or undef
+      if (((CI = dyn_cast<ConstantInt>(C1)) && CI->getValue()[0]) ||
+          ((CI = dyn_cast<ConstantInt>(C2)) && CI->getValue()[0]) ||
+          (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
+        return UndefValue::get(C1->getType());
+
+      // X * undef -> 0       otherwise
       return Constant::getNullValue(C1->getType());
+    }
     case Instruction::UDiv:
     case Instruction::SDiv:
+      // undef / 1 -> undef
+      if (Opcode == Instruction::UDiv || Opcode == Instruction::SDiv)
+        if (ConstantInt *CI2 = dyn_cast<ConstantInt>(C2))
+          if (CI2->isOne())
+            return C1;
+      // FALL THROUGH
     case Instruction::URem:
     case Instruction::SRem:
       if (!isa<UndefValue>(C2))                    // undef / X -> 0
         return Constant::getNullValue(C1->getType());
       return C2;                                   // X / undef -> undef
     case Instruction::Or:                          // X | undef -> -1
-      if (const VectorType *PTy = dyn_cast<VectorType>(C1->getType()))
-        return Constant::getAllOnesValue(PTy);
-      return Constant::getAllOnesValue(C1->getType());
+      if (isa<UndefValue>(C1) && isa<UndefValue>(C2)) // undef | undef -> undef
+        return C1;
+      return Constant::getAllOnesValue(C1->getType()); // undef | X -> ~0
     case Instruction::LShr:
       if (isa<UndefValue>(C2) && isa<UndefValue>(C1))
         return C1;                                  // undef lshr undef -> undef
@@ -1000,6 +1062,8 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       else
         return C1;                                  // X ashr undef --> X
     case Instruction::Shl:
+      if (isa<UndefValue>(C2) && isa<UndefValue>(C1))
+        return C1;                                  // undef shl undef -> undef
       // undef << X -> 0   or   X << undef -> 0
       return Constant::getNullValue(C1->getType());
     }
@@ -1695,7 +1759,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
             // with a single zero index, it must be nonzero.
             assert(CE1->getNumOperands() == 2 &&
                    !CE1->getOperand(1)->isNullValue() &&
-                   "Suprising getelementptr!");
+                   "Surprising getelementptr!");
             return isSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
           } else {
             // If they are different globals, we don't know what the value is,
@@ -1790,7 +1854,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   if (isa<UndefValue>(C1) || isa<UndefValue>(C2)) {
     // For EQ and NE, we can always pick a value for the undef to make the
     // predicate pass or fail, so we can return undef.
-    if (ICmpInst::isEquality(ICmpInst::Predicate(pred)))
+    // Also, if both operands are undef, we can return undef.
+    if (ICmpInst::isEquality(ICmpInst::Predicate(pred)) ||
+        (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
       return UndefValue::get(ResultTy);
     // Otherwise, pick the same value as the non-undef operand, and fold
     // it to true or false.
@@ -1906,11 +1972,11 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     // If we can constant fold the comparison of each element, constant fold
     // the whole vector comparison.
     SmallVector<Constant*, 4> ResElts;
-    for (unsigned i = 0, e = C1Elts.size(); i != e; ++i) {
-      // Compare the elements, producing an i1 result or constant expr.
+    // Compare the elements, producing an i1 result or constant expr.
+    for (unsigned i = 0, e = C1Elts.size(); i != e; ++i)
       ResElts.push_back(ConstantExpr::getCompare(pred, C1Elts[i], C2Elts[i]));
-    }
-    return ConstantVector::get(&ResElts[0], ResElts.size());
+
+    return ConstantVector::get(ResElts);
   }
 
   if (C1->getType()->isFloatingPointTy()) {
@@ -1958,7 +2024,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       else if (pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT) 
         Result = 1;
       break;
-    case ICmpInst::ICMP_NE: // We know that C1 != C2
+    case FCmpInst::FCMP_ONE: // We know that C1 != C2
       // We can only partially decide this relation.
       if (pred == FCmpInst::FCMP_OEQ || pred == FCmpInst::FCMP_UEQ) 
         Result = 0;