Don't eliminate bitcast instructions that change the type of a pointer
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 06dc9dc89f7986a64cac9208efb7ffd7f680ac91..1f75fe57d61884421da7d33a4116bb81b6d195d4 100644 (file)
@@ -96,27 +96,29 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) {
   // Check to see if we are casting a pointer to an aggregate to a pointer to
   // the first element.  If so, return the appropriate GEP instruction.
   if (const PointerType *PTy = dyn_cast<PointerType>(V->getType()))
-    if (const PointerType *DPTy = dyn_cast<PointerType>(DestTy)) {
-      SmallVector<Value*, 8> IdxList;
-      IdxList.push_back(Constant::getNullValue(Type::Int32Ty));
-      const Type *ElTy = PTy->getElementType();
-      while (ElTy != DPTy->getElementType()) {
-        if (const StructType *STy = dyn_cast<StructType>(ElTy)) {
-          if (STy->getNumElements() == 0) break;
-          ElTy = STy->getElementType(0);
-          IdxList.push_back(Constant::getNullValue(Type::Int32Ty));
-        } else if (const SequentialType *STy = dyn_cast<SequentialType>(ElTy)) {
-          if (isa<PointerType>(ElTy)) break;  // Can't index into pointers!
-          ElTy = STy->getElementType();
-          IdxList.push_back(IdxList[0]);
-        } else {
-          break;
+    if (const PointerType *DPTy = dyn_cast<PointerType>(DestTy))
+      if (PTy->getAddressSpace() == DPTy->getAddressSpace()) {
+        SmallVector<Value*, 8> IdxList;
+        IdxList.push_back(Constant::getNullValue(Type::Int32Ty));
+        const Type *ElTy = PTy->getElementType();
+        while (ElTy != DPTy->getElementType()) {
+          if (const StructType *STy = dyn_cast<StructType>(ElTy)) {
+            if (STy->getNumElements() == 0) break;
+            ElTy = STy->getElementType(0);
+            IdxList.push_back(Constant::getNullValue(Type::Int32Ty));
+          } else if (const SequentialType *STy = 
+                     dyn_cast<SequentialType>(ElTy)) {
+            if (isa<PointerType>(ElTy)) break;  // Can't index into pointers!
+            ElTy = STy->getElementType();
+            IdxList.push_back(IdxList[0]);
+          } else {
+            break;
+          }
         }
+        
+        if (ElTy == DPTy->getElementType())
+          return ConstantExpr::getGetElementPtr(V, &IdxList[0], IdxList.size());
       }
-      
-      if (ElTy == DPTy->getElementType())
-        return ConstantExpr::getGetElementPtr(V, &IdxList[0], IdxList.size());
-    }
   
   // Handle casts from one vector constant to another.  We know that the src 
   // and dest type have the same size (otherwise its an illegal cast).
@@ -170,8 +172,6 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) {
 
 Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
                                             const Type *DestTy) {
-  const Type *SrcTy = V->getType();
-
   if (isa<UndefValue>(V)) {
     // zext(undef) = 0, because the top bits will be zero.
     // sext(undef) = 0, because the top bits will all be the same.
@@ -257,12 +257,11 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
     if (const ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
       APInt api = CI->getValue();
       const uint64_t zero[] = {0, 0};
-      uint32_t BitWidth = cast<IntegerType>(SrcTy)->getBitWidth();
       APFloat apf = APFloat(APInt(DestTy->getPrimitiveSizeInBits(),
                                   2, zero));
-      (void)apf.convertFromZeroExtendedInteger(api.getRawData(), BitWidth
-                                   opc==Instruction::SIToFP,
-                                   APFloat::rmNearestTiesToEven);
+      (void)apf.convertFromAPInt(api
+                                 opc==Instruction::SIToFP,
+                                 APFloat::rmNearestTiesToEven);
       return ConstantFP::get(DestTy, apf);
     }
     if (const ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
@@ -477,9 +476,14 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
   // Handle UndefValue up front
   if (isa<UndefValue>(C1) || isa<UndefValue>(C2)) {
     switch (Opcode) {
+    case Instruction::Xor:
+      if (isa<UndefValue>(C1) && isa<UndefValue>(C2))
+        // Handle undef ^ undef -> 0 special case. This is a common
+        // idiom (misuse).
+        return Constant::getNullValue(C1->getType());
+      // Fallthrough
     case Instruction::Add:
     case Instruction::Sub:
-    case Instruction::Xor:
       return UndefValue::get(C1->getType());
     case Instruction::Mul:
     case Instruction::And:
@@ -660,25 +664,28 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       case Instruction::Xor:
         return ConstantInt::get(C1V ^ C2V);
       case Instruction::Shl:
-        if (uint32_t shiftAmt = C2V.getZExtValue())
+        if (uint32_t shiftAmt = C2V.getZExtValue()) {
           if (shiftAmt < C1V.getBitWidth())
             return ConstantInt::get(C1V.shl(shiftAmt));
           else
             return UndefValue::get(C1->getType()); // too big shift is undef
+        }
         return const_cast<ConstantInt*>(CI1); // Zero shift is identity
       case Instruction::LShr:
-        if (uint32_t shiftAmt = C2V.getZExtValue())
+        if (uint32_t shiftAmt = C2V.getZExtValue()) {
           if (shiftAmt < C1V.getBitWidth())
             return ConstantInt::get(C1V.lshr(shiftAmt));
           else
             return UndefValue::get(C1->getType()); // too big shift is undef
+        }
         return const_cast<ConstantInt*>(CI1); // Zero shift is identity
       case Instruction::AShr:
-        if (uint32_t shiftAmt = C2V.getZExtValue())
+        if (uint32_t shiftAmt = C2V.getZExtValue()) {
           if (shiftAmt < C1V.getBitWidth())
             return ConstantInt::get(C1V.ashr(shiftAmt));
           else
             return UndefValue::get(C1->getType()); // too big shift is undef
+        }
         return const_cast<ConstantInt*>(CI1); // Zero shift is identity
       }
     }
@@ -1083,18 +1090,20 @@ static ICmpInst::Predicate evaluateICmpRelation(const Constant *V1,
             // Ok, we ran out of things they have in common.  If any leftovers
             // are non-zero then we have a difference, otherwise we are equal.
             for (; i < CE1->getNumOperands(); ++i)
-              if (!CE1->getOperand(i)->isNullValue())
+              if (!CE1->getOperand(i)->isNullValue()) {
                 if (isa<ConstantInt>(CE1->getOperand(i)))
                   return isSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
                 else
                   return ICmpInst::BAD_ICMP_PREDICATE; // Might be equal.
+              }
 
             for (; i < CE2->getNumOperands(); ++i)
-              if (!CE2->getOperand(i)->isNullValue())
+              if (!CE2->getOperand(i)->isNullValue()) {
                 if (isa<ConstantInt>(CE2->getOperand(i)))
                   return isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
                 else
                   return ICmpInst::BAD_ICMP_PREDICATE; // Might be equal.
+              }
             return ICmpInst::ICMP_EQ;
           }
         }
@@ -1123,20 +1132,22 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   if (C1->isNullValue()) {
     if (const GlobalValue *GV = dyn_cast<GlobalValue>(C2))
       // Don't try to evaluate aliases.  External weak GV can be null.
-      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage())
+      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage()) {
         if (pred == ICmpInst::ICMP_EQ)
           return ConstantInt::getFalse();
         else if (pred == ICmpInst::ICMP_NE)
           return ConstantInt::getTrue();
+      }
   // icmp eq/ne(GV,null) -> false/true
   } else if (C2->isNullValue()) {
     if (const GlobalValue *GV = dyn_cast<GlobalValue>(C1))
       // Don't try to evaluate aliases.  External weak GV can be null.
-      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage())
+      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage()) {
         if (pred == ICmpInst::ICMP_EQ)
           return ConstantInt::getFalse();
         else if (pred == ICmpInst::ICMP_NE)
           return ConstantInt::getTrue();
+      }
   }
 
   if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {