- Somehow I forgot about one / une.
[oota-llvm.git] / lib / Transforms / Scalar / ScalarReplAggregates.cpp
index 0cd10ef0c16d7d60a4995b60dc828b4a46e8a4fd..b319d8da9512fd33492be26793b2f47e7c27fc3c 100644 (file)
@@ -48,7 +48,7 @@ STATISTIC(NumGlobals,   "Number of allocas copied from constant global");
 namespace {
   struct VISIBILITY_HIDDEN SROA : public FunctionPass {
     static char ID; // Pass identification, replacement for typeid
-    explicit SROA(signed T = -1) : FunctionPass((intptr_t)&ID) {
+    explicit SROA(signed T = -1) : FunctionPass(&ID) {
       if (T == -1)
         SRThreshold = 128;
       else
@@ -178,6 +178,14 @@ bool SROA::performPromotion(Function &F) {
   return Changed;
 }
 
+/// getNumSAElements - Return the number of elements in the specific struct or
+/// array.
+static uint64_t getNumSAElements(const Type *T) {
+  if (const StructType *ST = dyn_cast<StructType>(T))
+    return ST->getNumElements();
+  return cast<ArrayType>(T)->getNumElements();
+}
+
 // performScalarRepl - This algorithm is a simple worklist driven algorithm,
 // which runs on all of the malloc/alloca instructions in the function, removing
 // them if they are only used by getelementptr instructions.
@@ -224,7 +232,10 @@ bool SROA::performScalarRepl(Function &F) {
         (isa<StructType>(AI->getAllocatedType()) ||
          isa<ArrayType>(AI->getAllocatedType())) &&
         AI->getAllocatedType()->isSized() &&
-        TD.getABITypeSize(AI->getAllocatedType()) < SRThreshold) {
+        // Do not promote any struct whose size is larger than "128" bytes.
+        TD.getABITypeSize(AI->getAllocatedType()) < SRThreshold &&
+        // Do not promote any struct into more than "32" separate vars.
+        getNumSAElements(AI->getAllocatedType()) < SRThreshold/4) {
       // Check that all of the users of the allocation are capable of being
       // transformed.
       switch (isSafeAllocaToScalarRepl(AI)) {
@@ -302,6 +313,43 @@ void SROA::DoScalarReplacement(AllocationInst *AI,
       continue;
     }
     
+    // Replace:
+    //   %res = load { i32, i32 }* %alloc
+    // with:
+    //   %load.0 = load i32* %alloc.0
+    //   %insert.0 insertvalue { i32, i32 } zeroinitializer, i32 %load.0, 0 
+    //   %load.1 = load i32* %alloc.1
+    //   %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1 
+    // (Also works for arrays instead of structs)
+    if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
+      Value *Insert = UndefValue::get(LI->getType());
+      for (unsigned i = 0, e = ElementAllocas.size(); i != e; ++i) {
+        Value *Load = new LoadInst(ElementAllocas[i], "load", LI);
+        Insert = InsertValueInst::Create(Insert, Load, i, "insert", LI);
+      }
+      LI->replaceAllUsesWith(Insert);
+      LI->eraseFromParent();
+      continue;
+    }
+
+    // Replace:
+    //   store { i32, i32 } %val, { i32, i32 }* %alloc
+    // with:
+    //   %val.0 = extractvalue { i32, i32 } %val, 0 
+    //   store i32 %val.0, i32* %alloc.0
+    //   %val.1 = extractvalue { i32, i32 } %val, 1 
+    //   store i32 %val.1, i32* %alloc.1
+    // (Also works for arrays instead of structs)
+    if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+      Value *Val = SI->getOperand(0);
+      for (unsigned i = 0, e = ElementAllocas.size(); i != e; ++i) {
+        Value *Extract = ExtractValueInst::Create(Val, i, Val->getName(), SI);
+        new StoreInst(Extract, ElementAllocas[i], SI);
+      }
+      SI->eraseFromParent();
+      continue;
+    }
+    
     GetElementPtrInst *GEPI = cast<GetElementPtrInst>(User);
     // We now know that the GEP is of the form: GEP <ptr>, 0, <cst>
     unsigned Idx =
@@ -440,6 +488,12 @@ void SROA::isSafeUseOfAllocation(Instruction *User, AllocationInst *AI,
   if (BitCastInst *C = dyn_cast<BitCastInst>(User))
     return isSafeUseOfBitCastedAllocation(C, AI, Info);
 
+  if (isa<LoadInst>(User))
+    return; // Loads (returning a first class aggregrate) are always rewritable
+
+  if (isa<StoreInst>(User) && User->getOperand(0) != AI)
+    return; // Store is ok if storing INTO the pointer, not storing the pointer
   GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User);
   if (GEPI == 0)
     return MarkUnsafe(Info);
@@ -457,42 +511,12 @@ void SROA::isSafeUseOfAllocation(Instruction *User, AllocationInst *AI,
 
   bool IsAllZeroIndices = true;
   
-  // If this is a use of an array allocation, do a bit more checking for sanity.
+  // If the first index is a non-constant index into an array, see if we can
+  // handle it as a special case.
   if (const ArrayType *AT = dyn_cast<ArrayType>(*I)) {
-    uint64_t NumElements = AT->getNumElements();
-
-    if (ConstantInt *Idx = dyn_cast<ConstantInt>(I.getOperand())) {
-      IsAllZeroIndices &= Idx->isZero();
-      
-      // Check to make sure that index falls within the array.  If not,
-      // something funny is going on, so we won't do the optimization.
-      //
-      if (Idx->getZExtValue() >= NumElements)
-        return MarkUnsafe(Info);
-
-      // We cannot scalar repl this level of the array unless any array
-      // sub-indices are in-range constants.  In particular, consider:
-      // A[0][i].  We cannot know that the user isn't doing invalid things like
-      // allowing i to index an out-of-range subscript that accesses A[1].
-      //
-      // Scalar replacing *just* the outer index of the array is probably not
-      // going to be a win anyway, so just give up.
-      for (++I; I != E && (isa<ArrayType>(*I) || isa<VectorType>(*I)); ++I) {
-        uint64_t NumElements;
-        if (const ArrayType *SubArrayTy = dyn_cast<ArrayType>(*I))
-          NumElements = SubArrayTy->getNumElements();
-        else
-          NumElements = cast<VectorType>(*I)->getNumElements();
-        
-        ConstantInt *IdxVal = dyn_cast<ConstantInt>(I.getOperand());
-        if (!IdxVal) return MarkUnsafe(Info);
-        if (IdxVal->getZExtValue() >= NumElements)
-          return MarkUnsafe(Info);
-        IsAllZeroIndices &= IdxVal->isZero();
-      }
-      
-    } else {
+    if (!isa<ConstantInt>(I.getOperand())) {
       IsAllZeroIndices = 0;
+      uint64_t NumElements = AT->getNumElements();
       
       // If this is an array index and the index is not constant, we cannot
       // promote... that is unless the array has exactly one or two elements in
@@ -506,7 +530,42 @@ void SROA::isSafeUseOfAllocation(Instruction *User, AllocationInst *AI,
       return MarkUnsafe(Info);
     }
   }
+  bool hasVector = false;
+  
+  // Walk through the GEP type indices, checking the types that this indexes
+  // into.
+  for (; I != E; ++I) {
+    // Ignore struct elements, no extra checking needed for these.
+    if (isa<StructType>(*I))
+      continue;
+    
+    ConstantInt *IdxVal = dyn_cast<ConstantInt>(I.getOperand());
+    if (!IdxVal) return MarkUnsafe(Info);
 
+    // Are all indices still zero?
+    IsAllZeroIndices &= IdxVal->isZero();
+    
+    if (const ArrayType *AT = dyn_cast<ArrayType>(*I)) {
+      // This GEP indexes an array.  Verify that this is an in-range constant
+      // integer. Specifically, consider A[0][i]. We cannot know that the user
+      // isn't doing invalid things like allowing i to index an out-of-range
+      // subscript that accesses A[1].  Because of this, we have to reject SROA
+      // of any accesses into structs where any of the components are variables.
+      if (IdxVal->getZExtValue() >= AT->getNumElements())
+        return MarkUnsafe(Info);
+    }
+  
+    // Note if we've seen a vector type yet
+    hasVector |= isa<VectorType>(*I);
+    
+    // Don't SROA pointers into vectors, unless all indices are zero. When all
+    // indices are zero, we only consider this GEP as a bitcast, but will still
+    // not consider breaking up the vector.
+    if (hasVector && !IsAllZeroIndices)
+      return MarkUnsafe(Info);
+  }
+  
   // If there are any non-simple uses of this getelementptr, make sure to reject
   // them.
   return isSafeElementUse(GEPI, IsAllZeroIndices, AI, Info);
@@ -611,6 +670,11 @@ void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI,
       // It is likely that OtherPtr is a bitcast, if so, remove it.
       if (BitCastInst *BC = dyn_cast<BitCastInst>(OtherPtr))
         OtherPtr = BC->getOperand(0);
+      // All zero GEPs are effectively casts
+      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(OtherPtr))
+        if (GEP->hasAllZeroIndices())
+          OtherPtr = GEP->getOperand(0);
+        
       if (ConstantExpr *BCE = dyn_cast<ConstantExpr>(OtherPtr))
         if (BCE->getOpcode() == Instruction::BitCast)
           OtherPtr = BCE->getOperand(0);
@@ -631,11 +695,9 @@ void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI,
       // If this is a memcpy/memmove, emit a GEP of the other element address.
       Value *OtherElt = 0;
       if (OtherPtr) {
-        Value *Idx[2];
-        Idx[0] = Zero;
-        Idx[1] = ConstantInt::get(Type::Int32Ty, i);
+        Value *Idx[2] = { Zero, ConstantInt::get(Type::Int32Ty, i) };
         OtherElt = GetElementPtrInst::Create(OtherPtr, Idx, Idx + 2,
-                                             OtherPtr->getNameStr()+"."+utostr(i),
+                                           OtherPtr->getNameStr()+"."+utostr(i),
                                              MI);
       }
 
@@ -643,7 +705,7 @@ void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI,
       const Type *EltTy =cast<PointerType>(EltPtr->getType())->getElementType();
       
       // If we got down to a scalar, insert a load or store as appropriate.
-      if (EltTy->isFirstClassType()) {
+      if (EltTy->isSingleValueType()) {
         if (isa<MemCpyInst>(MI) || isa<MemMoveInst>(MI)) {
           Value *Elt = new LoadInst(SROADest ? OtherElt : EltPtr, "tmp",
                                     MI);
@@ -737,8 +799,7 @@ void SROA::RewriteBitCastUserOfAlloca(Instruction *BCInst, AllocationInst *AI,
 
 /// HasPadding - Return true if the specified type has any structure or
 /// alignment padding, false otherwise.
-static bool HasPadding(const Type *Ty, const TargetData &TD,
-                       bool inPacked = false) {
+static bool HasPadding(const Type *Ty, const TargetData &TD) {
   if (const StructType *STy = dyn_cast<StructType>(Ty)) {
     const StructLayout *SL = TD.getStructLayout(STy);
     unsigned PrevFieldBitOffset = 0;
@@ -746,7 +807,7 @@ static bool HasPadding(const Type *Ty, const TargetData &TD,
       unsigned FieldBitOffset = SL->getElementOffsetInBits(i);
 
       // Padding in sub-elements?
-      if (HasPadding(STy->getElementType(i), TD, STy->isPacked()))
+      if (HasPadding(STy->getElementType(i), TD))
         return true;
 
       // Check to see if there is any padding between this element and the
@@ -770,12 +831,11 @@ static bool HasPadding(const Type *Ty, const TargetData &TD,
     }
 
   } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
-    return HasPadding(ATy->getElementType(), TD, false);
+    return HasPadding(ATy->getElementType(), TD);
   } else if (const VectorType *VTy = dyn_cast<VectorType>(Ty)) {
-    return HasPadding(VTy->getElementType(), TD, false);
+    return HasPadding(VTy->getElementType(), TD);
   }
-  return inPacked ?
-    false : TD.getTypeSizeInBits(Ty) != TD.getABITypeSizeInBits(Ty);
+  return TD.getTypeSizeInBits(Ty) != TD.getABITypeSizeInBits(Ty);
 }
 
 /// isSafeStructAllocaToScalarRepl - Check to see if the specified allocation of
@@ -963,12 +1023,22 @@ const Type *SROA::CanConvertToScalar(Value *V, bool &IsNotTrivial) {
     Instruction *User = cast<Instruction>(*UI);
     
     if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
+      // FIXME: Loads of a first class aggregrate value could be converted to a
+      // series of loads and insertvalues
+      if (!LI->getType()->isSingleValueType())
+        return 0;
+
       if (MergeInType(LI->getType(), UsedType, TD))
         return 0;
       
     } else if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
       // Storing the pointer, not into the value?
       if (SI->getOperand(0) == V) return 0;
+
+      // FIXME: Stores of a first class aggregrate value could be converted to a
+      // series of extractvalues and stores
+      if (!SI->getOperand(0)->getType()->isSingleValueType())
+        return 0;
       
       // NOTE: We could handle storing of FP imms into integers here!
       
@@ -1204,11 +1274,11 @@ Value *SROA::ConvertUsesOfLoadToScalar(LoadInst *LI, AllocaInst *NewAI,
   // We do this to support (f.e.) loads off the end of a structure where
   // only some bits are used.
   if (ShAmt > 0 && (unsigned)ShAmt < NTy->getBitWidth())
-    NV = BinaryOperator::createLShr(NV, 
+    NV = BinaryOperator::CreateLShr(NV, 
                                     ConstantInt::get(NV->getType(),ShAmt),
                                     LI->getName(), LI);
   else if (ShAmt < 0 && (unsigned)-ShAmt < NTy->getBitWidth())
-    NV = BinaryOperator::createShl(NV, 
+    NV = BinaryOperator::CreateShl(NV, 
                                    ConstantInt::get(NV->getType(),-ShAmt),
                                    LI->getName(), LI);
   
@@ -1308,12 +1378,12 @@ Value *SROA::ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI,
     // only some bits in the structure are set.
     APInt Mask(APInt::getLowBitsSet(DestWidth, SrcWidth));
     if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) {
-      SV = BinaryOperator::createShl(SV, 
+      SV = BinaryOperator::CreateShl(SV, 
                                      ConstantInt::get(SV->getType(), ShAmt),
                                      SV->getName(), SI);
       Mask <<= ShAmt;
     } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) {
-      SV = BinaryOperator::createLShr(SV,
+      SV = BinaryOperator::CreateLShr(SV,
                                       ConstantInt::get(SV->getType(),-ShAmt),
                                       SV->getName(), SI);
       Mask = Mask.lshr(ShAmt);
@@ -1323,9 +1393,9 @@ Value *SROA::ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI,
     // in the new bits.
     if (SrcWidth != DestWidth) {
       assert(DestWidth > SrcWidth);
-      Old = BinaryOperator::createAnd(Old, ConstantInt::get(~Mask),
+      Old = BinaryOperator::CreateAnd(Old, ConstantInt::get(~Mask),
                                       Old->getName()+".mask", SI);
-      SV = BinaryOperator::createOr(Old, SV, SV->getName()+".ins", SI);
+      SV = BinaryOperator::CreateOr(Old, SV, SV->getName()+".ins", SI);
     }
   }
   return SV;