Change ScheduleDAG's SUnitMap from DenseMap<SDNode*, vector<SUnit*> >
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 50fbe1a00c4c8915635fabca1033c68cf345d89d..069c99ac83f6c12fc0c6f78cac272d9666b4202c 100644 (file)
@@ -332,10 +332,10 @@ Constant *llvm::ConstantFoldExtractElementInstruction(const Constant *Val,
   
   if (const ConstantVector *CVal = dyn_cast<ConstantVector>(Val)) {
     if (const ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
-      return const_cast<Constant*>(CVal->getOperand(CIdx->getZExtValue()));
+      return CVal->getOperand(CIdx->getZExtValue());
     } else if (isa<UndefValue>(Idx)) {
       // ee({w,x,y,z}, undef) -> w (an arbitrary value).
-      return const_cast<Constant*>(CVal->getOperand(0));
+      return CVal->getOperand(0);
     }
   }
   return 0;
@@ -394,6 +394,7 @@ Constant *llvm::ConstantFoldInsertElementInstruction(const Constant *Val,
     }
     return ConstantVector::get(Ops);
   }
+
   return 0;
 }
 
@@ -401,7 +402,7 @@ Constant *llvm::ConstantFoldInsertElementInstruction(const Constant *Val,
 /// return the specified element value.  Otherwise return null.
 static Constant *GetVectorElement(const Constant *C, unsigned EltNo) {
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(C))
-    return const_cast<Constant*>(CV->getOperand(EltNo));
+    return CV->getOperand(EltNo);
   
   const Type *EltTy = cast<VectorType>(C->getType())->getElementType();
   if (isa<ConstantAggregateZero>(C))
@@ -447,6 +448,115 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(const Constant *V1,
   return ConstantVector::get(&Result[0], Result.size());
 }
 
+Constant *llvm::ConstantFoldExtractValueInstruction(const Constant *Agg,
+                                                    const unsigned *Idxs,
+                                                    unsigned NumIdx) {
+  // Base case: no indices, so return the entire value.
+  if (NumIdx == 0)
+    return const_cast<Constant *>(Agg);
+
+  if (isa<UndefValue>(Agg))  // ev(undef, x) -> undef
+    return UndefValue::get(ExtractValueInst::getIndexedType(Agg->getType(),
+                                                            Idxs,
+                                                            Idxs + NumIdx));
+
+  if (isa<ConstantAggregateZero>(Agg))  // ev(0, x) -> 0
+    return
+      Constant::getNullValue(ExtractValueInst::getIndexedType(Agg->getType(),
+                                                              Idxs,
+                                                              Idxs + NumIdx));
+
+  // Otherwise recurse.
+  return ConstantFoldExtractValueInstruction(Agg->getOperand(*Idxs),
+                                             Idxs+1, NumIdx-1);
+}
+
+Constant *llvm::ConstantFoldInsertValueInstruction(const Constant *Agg,
+                                                   const Constant *Val,
+                                                   const unsigned *Idxs,
+                                                   unsigned NumIdx) {
+  // Base case: no indices, so replace the entire value.
+  if (NumIdx == 0)
+    return const_cast<Constant *>(Val);
+
+  if (isa<UndefValue>(Agg)) {
+    // Insertion of constant into aggregate undef
+    // Optimize away insertion of undef
+    if (isa<UndefValue>(Val))
+      return const_cast<Constant*>(Agg);
+    // Otherwise break the aggregate undef into multiple undefs and do
+    // the insertion
+    const CompositeType *AggTy = cast<CompositeType>(Agg->getType());
+    unsigned numOps;
+    if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy))
+      numOps = AR->getNumElements();
+    else
+      numOps = cast<StructType>(AggTy)->getNumElements();
+    std::vector<Constant*> Ops(numOps); 
+    for (unsigned i = 0; i < numOps; ++i) {
+      const Type *MemberTy = AggTy->getTypeAtIndex(i);
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(UndefValue::get(MemberTy),
+                                           Val, Idxs+1, NumIdx-1) :
+        UndefValue::get(MemberTy);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (isa<StructType>(AggTy))
+      return ConstantStruct::get(Ops);
+    else
+      return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
+  }
+  if (isa<ConstantAggregateZero>(Agg)) {
+    // Insertion of constant into aggregate zero
+    // Optimize away insertion of zero
+    if (Val->isNullValue())
+      return const_cast<Constant*>(Agg);
+    // Otherwise break the aggregate zero into multiple zeros and do
+    // the insertion
+    const CompositeType *AggTy = cast<CompositeType>(Agg->getType());
+    unsigned numOps;
+    if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy))
+      numOps = AR->getNumElements();
+    else
+      numOps = cast<StructType>(AggTy)->getNumElements();
+    std::vector<Constant*> Ops(numOps);
+    for (unsigned i = 0; i < numOps; ++i) {
+      const Type *MemberTy = AggTy->getTypeAtIndex(i);
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(Constant::getNullValue(MemberTy),
+                                           Val, Idxs+1, NumIdx-1) :
+        Constant::getNullValue(MemberTy);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (isa<StructType>(AggTy))
+      return ConstantStruct::get(Ops);
+    else
+      return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
+  }
+  if (isa<ConstantStruct>(Agg) || isa<ConstantArray>(Agg)) {
+    // Insertion of constant into aggregate constant
+    std::vector<Constant*> Ops(Agg->getNumOperands());
+    for (unsigned i = 0; i < Agg->getNumOperands(); ++i) {
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(Agg->getOperand(i),
+                                           Val, Idxs+1, NumIdx-1) :
+        Agg->getOperand(i);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    Constant *C;
+    if (isa<StructType>(Agg->getType()))
+      C = ConstantStruct::get(Ops);
+    else
+      C = ConstantArray::get(cast<ArrayType>(Agg->getType()), Ops);
+    return C;
+  }
+
+  return 0;
+}
+
 /// EvalVectorOp - Given two vector constants and a function pointer, apply the
 /// function pointer to each element pair, producing a new ConstantVector
 /// constant. Either or both of V1 and V2 may be NULL, meaning a
@@ -571,7 +681,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
           
           if (GVAlign > 1) {
             unsigned DstWidth = CI2->getType()->getBitWidth();
-            unsigned SrcWidth = std::min(SrcWidth, Log2_32(GVAlign));
+            unsigned SrcWidth = std::min(DstWidth, Log2_32(GVAlign));
             APInt BitsNotSet(APInt::getLowBitsSet(DstWidth, SrcWidth));
 
             // If checking bits we know are clear, return zero.
@@ -1222,9 +1332,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     if (const ConstantVector *CP2 = dyn_cast<ConstantVector>(C2)) {
       if (pred == FCmpInst::FCMP_OEQ || pred == FCmpInst::FCMP_UEQ) {
         for (unsigned i = 0, e = CP1->getNumOperands(); i != e; ++i) {
-          Constant *C= ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ,
-              const_cast<Constant*>(CP1->getOperand(i)),
-              const_cast<Constant*>(CP2->getOperand(i)));
+          Constant *C = ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ,
+                                              CP1->getOperand(i),
+                                              CP2->getOperand(i));
           if (ConstantInt *CB = dyn_cast<ConstantInt>(C))
             return CB;
         }
@@ -1233,8 +1343,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       } else if (pred == ICmpInst::ICMP_EQ) {
         for (unsigned i = 0, e = CP1->getNumOperands(); i != e; ++i) {
           Constant *C = ConstantExpr::getICmp(ICmpInst::ICMP_EQ,
-              const_cast<Constant*>(CP1->getOperand(i)),
-              const_cast<Constant*>(CP2->getOperand(i)));
+                                              CP1->getOperand(i),
+                                              CP2->getOperand(i));
           if (ConstantInt *CB = dyn_cast<ConstantInt>(C))
             return CB;
         }
@@ -1408,8 +1518,7 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
     const PointerType *Ptr = cast<PointerType>(C->getType());
     const Type *Ty = GetElementPtrInst::getIndexedType(Ptr,
                                                        (Value **)Idxs,
-                                                       (Value **)Idxs+NumIdx,
-                                                       true);
+                                                       (Value **)Idxs+NumIdx);
     assert(Ty != 0 && "Invalid indices for GEP!");
     return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace()));
   }
@@ -1426,8 +1535,7 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
       const PointerType *Ptr = cast<PointerType>(C->getType());
       const Type *Ty = GetElementPtrInst::getIndexedType(Ptr,
                                                          (Value**)Idxs,
-                                                         (Value**)Idxs+NumIdx,
-                                                         true);
+                                                         (Value**)Idxs+NumIdx);
       assert(Ty != 0 && "Invalid indices for GEP!");
       return 
         ConstantPointerNull::get(PointerType::get(Ty,Ptr->getAddressSpace()));