Change ScheduleDAG's SUnitMap from DenseMap<SDNode*, vector<SUnit*> >
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index c15ce96959788fb83520187146d140122b70f631..069c99ac83f6c12fc0c6f78cac272d9666b4202c 100644 (file)
@@ -150,7 +150,7 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) {
     if (DestTy->isFloatingPoint()) {
       assert((DestTy == Type::DoubleTy || DestTy == Type::FloatTy) && 
              "Unknown FP type!");
-      return ConstantFP::get(DestTy, APFloat(CI->getValue()));
+      return ConstantFP::get(APFloat(CI->getValue()));
     }
     // Otherwise, can't fold this (vector?)
     return 0;
@@ -220,7 +220,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
                   DestTy == Type::FP128Ty ? APFloat::IEEEquad :
                   APFloat::Bogus,
                   APFloat::rmNearestTiesToEven);
-      return ConstantFP::get(DestTy, Val);
+      return ConstantFP::get(Val);
     }
     return 0; // Can't fold.
   case Instruction::FPToUI: 
@@ -262,7 +262,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
       (void)apf.convertFromAPInt(api, 
                                  opc==Instruction::SIToFP,
                                  APFloat::rmNearestTiesToEven);
-      return ConstantFP::get(DestTy, apf);
+      return ConstantFP::get(apf);
     }
     if (const ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
       std::vector<Constant*> res;
@@ -332,10 +332,10 @@ Constant *llvm::ConstantFoldExtractElementInstruction(const Constant *Val,
   
   if (const ConstantVector *CVal = dyn_cast<ConstantVector>(Val)) {
     if (const ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
-      return const_cast<Constant*>(CVal->getOperand(CIdx->getZExtValue()));
+      return CVal->getOperand(CIdx->getZExtValue());
     } else if (isa<UndefValue>(Idx)) {
       // ee({w,x,y,z}, undef) -> w (an arbitrary value).
-      return const_cast<Constant*>(CVal->getOperand(0));
+      return CVal->getOperand(0);
     }
   }
   return 0;
@@ -394,6 +394,7 @@ Constant *llvm::ConstantFoldInsertElementInstruction(const Constant *Val,
     }
     return ConstantVector::get(Ops);
   }
+
   return 0;
 }
 
@@ -401,7 +402,7 @@ Constant *llvm::ConstantFoldInsertElementInstruction(const Constant *Val,
 /// return the specified element value.  Otherwise return null.
 static Constant *GetVectorElement(const Constant *C, unsigned EltNo) {
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(C))
-    return const_cast<Constant*>(CV->getOperand(EltNo));
+    return CV->getOperand(EltNo);
   
   const Type *EltTy = cast<VectorType>(C->getType())->getElementType();
   if (isa<ConstantAggregateZero>(C))
@@ -447,6 +448,115 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(const Constant *V1,
   return ConstantVector::get(&Result[0], Result.size());
 }
 
+Constant *llvm::ConstantFoldExtractValueInstruction(const Constant *Agg,
+                                                    const unsigned *Idxs,
+                                                    unsigned NumIdx) {
+  // Base case: no indices, so return the entire value.
+  if (NumIdx == 0)
+    return const_cast<Constant *>(Agg);
+
+  if (isa<UndefValue>(Agg))  // ev(undef, x) -> undef
+    return UndefValue::get(ExtractValueInst::getIndexedType(Agg->getType(),
+                                                            Idxs,
+                                                            Idxs + NumIdx));
+
+  if (isa<ConstantAggregateZero>(Agg))  // ev(0, x) -> 0
+    return
+      Constant::getNullValue(ExtractValueInst::getIndexedType(Agg->getType(),
+                                                              Idxs,
+                                                              Idxs + NumIdx));
+
+  // Otherwise recurse.
+  return ConstantFoldExtractValueInstruction(Agg->getOperand(*Idxs),
+                                             Idxs+1, NumIdx-1);
+}
+
+Constant *llvm::ConstantFoldInsertValueInstruction(const Constant *Agg,
+                                                   const Constant *Val,
+                                                   const unsigned *Idxs,
+                                                   unsigned NumIdx) {
+  // Base case: no indices, so replace the entire value.
+  if (NumIdx == 0)
+    return const_cast<Constant *>(Val);
+
+  if (isa<UndefValue>(Agg)) {
+    // Insertion of constant into aggregate undef
+    // Optimize away insertion of undef
+    if (isa<UndefValue>(Val))
+      return const_cast<Constant*>(Agg);
+    // Otherwise break the aggregate undef into multiple undefs and do
+    // the insertion
+    const CompositeType *AggTy = cast<CompositeType>(Agg->getType());
+    unsigned numOps;
+    if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy))
+      numOps = AR->getNumElements();
+    else
+      numOps = cast<StructType>(AggTy)->getNumElements();
+    std::vector<Constant*> Ops(numOps); 
+    for (unsigned i = 0; i < numOps; ++i) {
+      const Type *MemberTy = AggTy->getTypeAtIndex(i);
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(UndefValue::get(MemberTy),
+                                           Val, Idxs+1, NumIdx-1) :
+        UndefValue::get(MemberTy);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (isa<StructType>(AggTy))
+      return ConstantStruct::get(Ops);
+    else
+      return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
+  }
+  if (isa<ConstantAggregateZero>(Agg)) {
+    // Insertion of constant into aggregate zero
+    // Optimize away insertion of zero
+    if (Val->isNullValue())
+      return const_cast<Constant*>(Agg);
+    // Otherwise break the aggregate zero into multiple zeros and do
+    // the insertion
+    const CompositeType *AggTy = cast<CompositeType>(Agg->getType());
+    unsigned numOps;
+    if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy))
+      numOps = AR->getNumElements();
+    else
+      numOps = cast<StructType>(AggTy)->getNumElements();
+    std::vector<Constant*> Ops(numOps);
+    for (unsigned i = 0; i < numOps; ++i) {
+      const Type *MemberTy = AggTy->getTypeAtIndex(i);
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(Constant::getNullValue(MemberTy),
+                                           Val, Idxs+1, NumIdx-1) :
+        Constant::getNullValue(MemberTy);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (isa<StructType>(AggTy))
+      return ConstantStruct::get(Ops);
+    else
+      return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
+  }
+  if (isa<ConstantStruct>(Agg) || isa<ConstantArray>(Agg)) {
+    // Insertion of constant into aggregate constant
+    std::vector<Constant*> Ops(Agg->getNumOperands());
+    for (unsigned i = 0; i < Agg->getNumOperands(); ++i) {
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(Agg->getOperand(i),
+                                           Val, Idxs+1, NumIdx-1) :
+        Agg->getOperand(i);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    Constant *C;
+    if (isa<StructType>(Agg->getType()))
+      C = ConstantStruct::get(Ops);
+    else
+      C = ConstantArray::get(cast<ArrayType>(Agg->getType()), Ops);
+    return C;
+  }
+
+  return 0;
+}
+
 /// EvalVectorOp - Given two vector constants and a function pointer, apply the
 /// function pointer to each element pair, producing a new ConstantVector
 /// constant. Either or both of V1 and V2 may be NULL, meaning a
@@ -548,8 +658,8 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       if (CI2->isAllOnesValue())
         return const_cast<Constant*>(C1);                     // X & -1 == X
       
-      // (zext i32 to i64) & 4294967295 -> (zext i32 to i64)
       if (const ConstantExpr *CE1 = dyn_cast<ConstantExpr>(C1)) {
+        // (zext i32 to i64) & 4294967295 -> (zext i32 to i64)
         if (CE1->getOpcode() == Instruction::ZExt) {
           unsigned DstWidth = CI2->getType()->getBitWidth();
           unsigned SrcWidth =
@@ -559,16 +669,25 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
             return const_cast<Constant*>(C1);
         }
         
+        // If and'ing the address of a global with a constant, fold it.
         if (CE1->getOpcode() == Instruction::PtrToInt && 
             isa<GlobalValue>(CE1->getOperand(0))) {
-          GlobalValue *CPR = cast<GlobalValue>(CE1->getOperand(0));
+          GlobalValue *GV = cast<GlobalValue>(CE1->getOperand(0));
         
-          // Functions are at least 4-byte aligned.  If and'ing the address of a
-          // function with a constant < 4, fold it to zero.
-          if (const ConstantInt *CI = dyn_cast<ConstantInt>(C2))
-            if (CI->getValue().ult(APInt(CI->getType()->getBitWidth(),4)) && 
-                isa<Function>(CPR))
-              return Constant::getNullValue(CI->getType());
+          // Functions are at least 4-byte aligned.
+          unsigned GVAlign = GV->getAlignment();
+          if (isa<Function>(GV))
+            GVAlign = std::max(GVAlign, 4U);
+          
+          if (GVAlign > 1) {
+            unsigned DstWidth = CI2->getType()->getBitWidth();
+            unsigned SrcWidth = std::min(DstWidth, Log2_32(GVAlign));
+            APInt BitsNotSet(APInt::getLowBitsSet(DstWidth, SrcWidth));
+
+            // If checking bits we know are clear, return zero.
+            if ((CI2->getValue() & BitsNotSet) == CI2->getValue())
+              return Constant::getNullValue(CI2->getType());
+          }
         }
       }
       break;
@@ -590,44 +709,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
     }
   }
   
-  if (isa<ConstantExpr>(C1)) {
-    // There are many possible foldings we could do here.  We should probably
-    // at least fold add of a pointer with an integer into the appropriate
-    // getelementptr.  This will improve alias analysis a bit.
-  } else if (isa<ConstantExpr>(C2)) {
-    // If C2 is a constant expr and C1 isn't, flop them around and fold the
-    // other way if possible.
-    switch (Opcode) {
-    case Instruction::Add:
-    case Instruction::Mul:
-    case Instruction::And:
-    case Instruction::Or:
-    case Instruction::Xor:
-      // No change of opcode required.
-      return ConstantFoldBinaryInstruction(Opcode, C2, C1);
-
-    case Instruction::Shl:
-    case Instruction::LShr:
-    case Instruction::AShr:
-    case Instruction::Sub:
-    case Instruction::SDiv:
-    case Instruction::UDiv:
-    case Instruction::FDiv:
-    case Instruction::URem:
-    case Instruction::SRem:
-    case Instruction::FRem:
-    default:  // These instructions cannot be flopped around.
-      return 0;
-    }
-  }
-
-  // At this point we know neither constant is an UndefValue nor a ConstantExpr
-  // so look at directly computing the value.
+  // At this point we know neither constant is an UndefValue.
   if (const ConstantInt *CI1 = dyn_cast<ConstantInt>(C1)) {
     if (const ConstantInt *CI2 = dyn_cast<ConstantInt>(C2)) {
       using namespace APIntOps;
-      APInt C1V = CI1->getValue();
-      APInt C2V = CI2->getValue();
+      const APInt &C1V = CI1->getValue();
+      const APInt &C2V = CI2->getValue();
       switch (Opcode) {
       default:
         break;
@@ -663,30 +750,27 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return ConstantInt::get(C1V | C2V);
       case Instruction::Xor:
         return ConstantInt::get(C1V ^ C2V);
-      case Instruction::Shl:
-        if (uint32_t shiftAmt = C2V.getZExtValue()) {
-          if (shiftAmt < C1V.getBitWidth())
-            return ConstantInt::get(C1V.shl(shiftAmt));
-          else
-            return UndefValue::get(C1->getType()); // too big shift is undef
-        }
-        return const_cast<ConstantInt*>(CI1); // Zero shift is identity
-      case Instruction::LShr:
-        if (uint32_t shiftAmt = C2V.getZExtValue()) {
-          if (shiftAmt < C1V.getBitWidth())
-            return ConstantInt::get(C1V.lshr(shiftAmt));
-          else
-            return UndefValue::get(C1->getType()); // too big shift is undef
-        }
-        return const_cast<ConstantInt*>(CI1); // Zero shift is identity
-      case Instruction::AShr:
-        if (uint32_t shiftAmt = C2V.getZExtValue()) {
-          if (shiftAmt < C1V.getBitWidth())
-            return ConstantInt::get(C1V.ashr(shiftAmt));
-          else
-            return UndefValue::get(C1->getType()); // too big shift is undef
-        }
-        return const_cast<ConstantInt*>(CI1); // Zero shift is identity
+      case Instruction::Shl: {
+        uint32_t shiftAmt = C2V.getZExtValue();
+        if (shiftAmt < C1V.getBitWidth())
+          return ConstantInt::get(C1V.shl(shiftAmt));
+        else
+          return UndefValue::get(C1->getType()); // too big shift is undef
+      }
+      case Instruction::LShr: {
+        uint32_t shiftAmt = C2V.getZExtValue();
+        if (shiftAmt < C1V.getBitWidth())
+          return ConstantInt::get(C1V.lshr(shiftAmt));
+        else
+          return UndefValue::get(C1->getType()); // too big shift is undef
+      }
+      case Instruction::AShr: {
+        uint32_t shiftAmt = C2V.getZExtValue();
+        if (shiftAmt < C1V.getBitWidth())
+          return ConstantInt::get(C1V.ashr(shiftAmt));
+        else
+          return UndefValue::get(C1->getType()); // too big shift is undef
+      }
       }
     }
   } else if (const ConstantFP *CFP1 = dyn_cast<ConstantFP>(C1)) {
@@ -694,30 +778,34 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       APFloat C1V = CFP1->getValueAPF();
       APFloat C2V = CFP2->getValueAPF();
       APFloat C3V = C1V;  // copy for modification
-      bool isDouble = CFP1->getType()==Type::DoubleTy;
       switch (Opcode) {
       default:                   
         break;
       case Instruction::Add:
         (void)C3V.add(C2V, APFloat::rmNearestTiesToEven);
-        return ConstantFP::get(CFP1->getType(), C3V);
+        return ConstantFP::get(C3V);
       case Instruction::Sub:     
         (void)C3V.subtract(C2V, APFloat::rmNearestTiesToEven);
-        return ConstantFP::get(CFP1->getType(), C3V);
+        return ConstantFP::get(C3V);
       case Instruction::Mul:
         (void)C3V.multiply(C2V, APFloat::rmNearestTiesToEven);
-        return ConstantFP::get(CFP1->getType(), C3V);
+        return ConstantFP::get(C3V);
       case Instruction::FDiv:
         (void)C3V.divide(C2V, APFloat::rmNearestTiesToEven);
-        return ConstantFP::get(CFP1->getType(), C3V);
+        return ConstantFP::get(C3V);
       case Instruction::FRem:
-        if (C2V.isZero())
+        if (C2V.isZero()) {
           // IEEE 754, Section 7.1, #5
-          return ConstantFP::get(CFP1->getType(), isDouble ?
-                            APFloat(std::numeric_limits<double>::quiet_NaN()) :
-                            APFloat(std::numeric_limits<float>::quiet_NaN()));
+          if (CFP1->getType() == Type::DoubleTy)
+            return ConstantFP::get(APFloat(std::numeric_limits<double>::
+                                           quiet_NaN()));
+          if (CFP1->getType() == Type::FloatTy)
+            return ConstantFP::get(APFloat(std::numeric_limits<float>::
+                                           quiet_NaN()));
+          break;
+        }
         (void)C3V.mod(C2V, APFloat::rmNearestTiesToEven);
-        return ConstantFP::get(CFP1->getType(), C3V);
+        return ConstantFP::get(C3V);
       }
     }
   } else if (const VectorType *VTy = dyn_cast<VectorType>(C1->getType())) {
@@ -756,7 +844,38 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
     }
   }
 
-  // We don't know how to fold this
+  if (isa<ConstantExpr>(C1)) {
+    // There are many possible foldings we could do here.  We should probably
+    // at least fold add of a pointer with an integer into the appropriate
+    // getelementptr.  This will improve alias analysis a bit.
+  } else if (isa<ConstantExpr>(C2)) {
+    // If C2 is a constant expr and C1 isn't, flop them around and fold the
+    // other way if possible.
+    switch (Opcode) {
+    case Instruction::Add:
+    case Instruction::Mul:
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor:
+      // No change of opcode required.
+      return ConstantFoldBinaryInstruction(Opcode, C2, C1);
+      
+    case Instruction::Shl:
+    case Instruction::LShr:
+    case Instruction::AShr:
+    case Instruction::Sub:
+    case Instruction::SDiv:
+    case Instruction::UDiv:
+    case Instruction::FDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+    case Instruction::FRem:
+    default:  // These instructions cannot be flopped around.
+      break;
+    }
+  }
+  
+  // We don't know how to fold this.
   return 0;
 }
 
@@ -1213,9 +1332,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     if (const ConstantVector *CP2 = dyn_cast<ConstantVector>(C2)) {
       if (pred == FCmpInst::FCMP_OEQ || pred == FCmpInst::FCMP_UEQ) {
         for (unsigned i = 0, e = CP1->getNumOperands(); i != e; ++i) {
-          Constant *C= ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ,
-              const_cast<Constant*>(CP1->getOperand(i)),
-              const_cast<Constant*>(CP2->getOperand(i)));
+          Constant *C = ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ,
+                                              CP1->getOperand(i),
+                                              CP2->getOperand(i));
           if (ConstantInt *CB = dyn_cast<ConstantInt>(C))
             return CB;
         }
@@ -1224,8 +1343,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       } else if (pred == ICmpInst::ICMP_EQ) {
         for (unsigned i = 0, e = CP1->getNumOperands(); i != e; ++i) {
           Constant *C = ConstantExpr::getICmp(ICmpInst::ICMP_EQ,
-              const_cast<Constant*>(CP1->getOperand(i)),
-              const_cast<Constant*>(CP2->getOperand(i)));
+                                              CP1->getOperand(i),
+                                              CP2->getOperand(i));
           if (ConstantInt *CB = dyn_cast<ConstantInt>(C))
             return CB;
         }
@@ -1399,8 +1518,7 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
     const PointerType *Ptr = cast<PointerType>(C->getType());
     const Type *Ty = GetElementPtrInst::getIndexedType(Ptr,
                                                        (Value **)Idxs,
-                                                       (Value **)Idxs+NumIdx,
-                                                       true);
+                                                       (Value **)Idxs+NumIdx);
     assert(Ty != 0 && "Invalid indices for GEP!");
     return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace()));
   }
@@ -1417,8 +1535,7 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
       const PointerType *Ptr = cast<PointerType>(C->getType());
       const Type *Ty = GetElementPtrInst::getIndexedType(Ptr,
                                                          (Value**)Idxs,
-                                                         (Value**)Idxs+NumIdx,
-                                                         true);
+                                                         (Value**)Idxs+NumIdx);
       assert(Ty != 0 && "Invalid indices for GEP!");
       return 
         ConstantPointerNull::get(PointerType::get(Ty,Ptr->getAddressSpace()));