Sink the return instruction collection until after we're done deleting
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index abc019ff4e000d6dc97acbc6c076beacaeec6a28..b743287adf3941fd3a31f3fee8e9063d99672d01 100644 (file)
@@ -547,18 +547,17 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   // If the cast operand is a constant vector, perform the cast by
   // operating on each element. In the cast of bitcasts, the element
   // count may be mismatched; don't attempt to handle that here.
-  if (ConstantVector *CV = dyn_cast<ConstantVector>(V))
-    if (DestTy->isVectorTy() &&
-        cast<VectorType>(DestTy)->getNumElements() ==
-        CV->getType()->getNumElements()) {
-      std::vector<Constant*> res;
-      VectorType *DestVecTy = cast<VectorType>(DestTy);
-      Type *DstEltTy = DestVecTy->getElementType();
-      for (unsigned i = 0, e = CV->getType()->getNumElements(); i != e; ++i)
-        res.push_back(ConstantExpr::getCast(opc,
-                                            CV->getOperand(i), DstEltTy));
-      return ConstantVector::get(res);
-    }
+  if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
+      DestTy->isVectorTy() &&
+      DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) {
+    SmallVector<Constant*, 16> res;
+    VectorType *DestVecTy = cast<VectorType>(DestTy);
+    Type *DstEltTy = DestVecTy->getElementType();
+    for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i)
+      res.push_back(ConstantExpr::getCast(opc,
+                                          V->getAggregateElement(i), DstEltTy));
+    return ConstantVector::get(res);
+  }
 
   // We actually have to do a cast now. Perform the cast according to the
   // opcode specified.
@@ -694,7 +693,6 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
   if (Cond->isNullValue()) return V2;
   if (Cond->isAllOnesValue()) return V1;
 
-  // FIXME: CDV Condition.
   // If the condition is a vector constant, fold the result elementwise.
   if (ConstantVector *CondV = dyn_cast<ConstantVector>(Cond)) {
     SmallVector<Constant*, 16> Result;
@@ -712,7 +710,6 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
       return ConstantVector::get(Result);
   }
 
-
   if (isa<UndefValue>(Cond)) {
     if (isa<UndefValue>(V1)) return V1;
     return V2;
@@ -738,22 +735,19 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
 Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
                                                       Constant *Idx) {
   if (isa<UndefValue>(Val))  // ee(undef, x) -> undef
-    return UndefValue::get(cast<VectorType>(Val->getType())->getElementType());
+    return UndefValue::get(Val->getType()->getVectorElementType());
   if (Val->isNullValue())  // ee(zero, x) -> zero
-    return Constant::getNullValue(
-                          cast<VectorType>(Val->getType())->getElementType());
-
-  if (ConstantVector *CVal = dyn_cast<ConstantVector>(Val)) {
-    if (ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
-      uint64_t Index = CIdx->getZExtValue();
-      if (Index >= CVal->getNumOperands())
-        // ee({w,x,y,z}, wrong_value) -> undef
-        return UndefValue::get(cast<VectorType>(Val->getType())->getElementType());
-      return CVal->getOperand(CIdx->getZExtValue());
-    } else if (isa<UndefValue>(Idx)) {
-      // ee({w,x,y,z}, undef) -> undef
-      return UndefValue::get(cast<VectorType>(Val->getType())->getElementType());
-    }
+    return Constant::getNullValue(Val->getType()->getVectorElementType());
+  // ee({w,x,y,z}, undef) -> undef
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Val->getType()->getVectorElementType());
+
+  if (ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
+    uint64_t Index = CIdx->getZExtValue();
+    // ee({w,x,y,z}, wrong_value) -> undef
+    if (Index >= Val->getType()->getVectorNumElements())
+      return UndefValue::get(Val->getType()->getVectorElementType());
+    return Val->getAggregateElement(Index);
   }
   return 0;
 }
@@ -784,27 +778,30 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
 Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1,
                                                      Constant *V2,
                                                      Constant *Mask) {
+  unsigned MaskNumElts = Mask->getType()->getVectorNumElements();
+  Type *EltTy = V1->getType()->getVectorElementType();
+
   // Undefined shuffle mask -> undefined value.
-  if (isa<UndefValue>(Mask)) return UndefValue::get(V1->getType());
+  if (isa<UndefValue>(Mask))
+    return UndefValue::get(VectorType::get(EltTy, MaskNumElts));
 
-  unsigned MaskNumElts = cast<VectorType>(Mask->getType())->getNumElements();
-  unsigned SrcNumElts = cast<VectorType>(V1->getType())->getNumElements();
-  Type *EltTy = cast<VectorType>(V1->getType())->getElementType();
+  // Don't break the bitcode reader hack.
+  if (isa<ConstantExpr>(Mask)) return 0;
+  
+  unsigned SrcNumElts = V1->getType()->getVectorNumElements();
 
   // Loop over the shuffle mask, evaluating each element.
   SmallVector<Constant*, 32> Result;
   for (unsigned i = 0; i != MaskNumElts; ++i) {
-    Constant *InElt = Mask->getAggregateElement(i);
-    if (InElt == 0) return 0;
-
-    if (isa<UndefValue>(InElt)) {
+    int Elt = ShuffleVectorInst::getMaskValue(Mask, i);
+    if (Elt == -1) {
       Result.push_back(UndefValue::get(EltTy));
       continue;
     }
-    unsigned Elt = cast<ConstantInt>(InElt)->getZExtValue();
-    if (Elt >= SrcNumElts*2)
+    Constant *InElt;
+    if (unsigned(Elt) >= SrcNumElts*2)
       InElt = UndefValue::get(EltTy);
-    else if (Elt >= SrcNumElts)
+    else if (unsigned(Elt) >= SrcNumElts)
       InElt = V2->getAggregateElement(Elt - SrcNumElts);
     else
       InElt = V1->getAggregateElement(Elt);