Follow-up fix to r165928: handle memset rewriting for widened integers,
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCalls.cpp
index cbe1ca4ddcec08a0e74e9113ffd8864907160924..44ddf3be34e63c9b221873e638a34da655d06396 100644 (file)
@@ -13,7 +13,7 @@
 
 #include "InstCombine.h"
 #include "llvm/Support/CallSite.h"
-#include "llvm/Target/TargetData.h"
+#include "llvm/DataLayout.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -29,6 +29,26 @@ static Type *getPromotedType(Type *Ty) {
   return Ty;
 }
 
+/// reduceToSingleValueType - Given an aggregate type which ultimately holds a
+/// single scalar element, like {{{type}}} or [1 x type], return type.
+static Type *reduceToSingleValueType(Type *T) {
+  while (!T->isSingleValueType()) {
+    if (StructType *STy = dyn_cast<StructType>(T)) {
+      if (STy->getNumElements() == 1)
+        T = STy->getElementType(0);
+      else
+        break;
+    } else if (ArrayType *ATy = dyn_cast<ArrayType>(T)) {
+      if (ATy->getNumElements() == 1)
+        T = ATy->getElementType();
+      else
+        break;
+    } else
+      break;
+  }
+
+  return T;
+}
 
 Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), TD);
@@ -74,35 +94,37 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   // dest address will be promotable.  See if we can find a better type than the
   // integer datatype.
   Value *StrippedDest = MI->getArgOperand(0)->stripPointerCasts();
+  MDNode *CopyMD = 0;
   if (StrippedDest != MI->getArgOperand(0)) {
     Type *SrcETy = cast<PointerType>(StrippedDest->getType())
                                     ->getElementType();
     if (TD && SrcETy->isSized() && TD->getTypeStoreSize(SrcETy) == Size) {
       // The SrcETy might be something like {{{double}}} or [1 x double].  Rip
       // down through these levels if so.
-      while (!SrcETy->isSingleValueType()) {
-        if (StructType *STy = dyn_cast<StructType>(SrcETy)) {
-          if (STy->getNumElements() == 1)
-            SrcETy = STy->getElementType(0);
-          else
-            break;
-        } else if (ArrayType *ATy = dyn_cast<ArrayType>(SrcETy)) {
-          if (ATy->getNumElements() == 1)
-            SrcETy = ATy->getElementType();
-          else
-            break;
-        } else
-          break;
-      }
+      SrcETy = reduceToSingleValueType(SrcETy);
 
       if (SrcETy->isSingleValueType()) {
         NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp);
         NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp);
+
+        // If the memcpy has metadata describing the members, see if we can
+        // get the TBAA tag describing our copy.
+        if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) {
+          if (M->getNumOperands() == 3 &&
+              M->getOperand(0) &&
+              isa<ConstantInt>(M->getOperand(0)) &&
+              cast<ConstantInt>(M->getOperand(0))->isNullValue() &&
+              M->getOperand(1) &&
+              isa<ConstantInt>(M->getOperand(1)) &&
+              cast<ConstantInt>(M->getOperand(1))->getValue() == Size &&
+              M->getOperand(2) &&
+              isa<MDNode>(M->getOperand(2)))
+            CopyMD = cast<MDNode>(M->getOperand(2));
+        }
       }
     }
   }
 
-
   // If the memcpy/memmove provides better alignment info than we can
   // infer, use it.
   SrcAlign = std::max(SrcAlign, CopyAlign);
@@ -112,8 +134,12 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   Value *Dest = Builder->CreateBitCast(MI->getArgOperand(0), NewDstPtrTy);
   LoadInst *L = Builder->CreateLoad(Src, MI->isVolatile());
   L->setAlignment(SrcAlign);
+  if (CopyMD)
+    L->setMetadata(LLVMContext::MD_tbaa, CopyMD);
   StoreInst *S = Builder->CreateStore(L, Dest, MI->isVolatile());
   S->setAlignment(DstAlign);
+  if (CopyMD)
+    S->setMetadata(LLVMContext::MD_tbaa, CopyMD);
 
   // Set the size of the copy to 0, it will be deleted on the next iteration.
   MI->setArgOperand(2, Constant::getNullValue(MemOpLength->getType()));
@@ -168,7 +194,7 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
 /// the heavy lifting.
 ///
 Instruction *InstCombiner::visitCallInst(CallInst &CI) {
-  if (isFreeCall(&CI))
+  if (isFreeCall(&CI, TLI))
     return visitFree(CI);
 
   // If the caller function is nounwind, mark the call as nounwind, even if the
@@ -243,7 +269,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   default: break;
   case Intrinsic::objectsize: {
     uint64_t Size;
-    if (getObjectSize(II->getArgOperand(0), Size, TD))
+    if (getObjectSize(II->getArgOperand(0), Size, TD, TLI))
       return ReplaceInstUsesWith(CI, ConstantInt::get(CI.getType(), Size));
     return 0;
   }
@@ -731,7 +757,7 @@ Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) {
 /// passed through the varargs area, we can eliminate the use of the cast.
 static bool isSafeToEliminateVarargsCast(const CallSite CS,
                                          const CastInst * const CI,
-                                         const TargetData * const TD,
+                                         const DataLayout * const TD,
                                          const int ix) {
   if (!CI->isLosslessCast())
     return false;
@@ -752,49 +778,17 @@ static bool isSafeToEliminateVarargsCast(const CallSite CS,
   return true;
 }
 
-namespace {
-class InstCombineFortifiedLibCalls : public SimplifyFortifiedLibCalls {
-  InstCombiner *IC;
-protected:
-  void replaceCall(Value *With) {
-    NewInstruction = IC->ReplaceInstUsesWith(*CI, With);
-  }
-  bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, bool isString) const {
-    if (CI->getArgOperand(SizeCIOp) == CI->getArgOperand(SizeArgOp))
-      return true;
-    if (ConstantInt *SizeCI =
-                           dyn_cast<ConstantInt>(CI->getArgOperand(SizeCIOp))) {
-      if (SizeCI->isAllOnesValue())
-        return true;
-      if (isString) {
-        uint64_t Len = GetStringLength(CI->getArgOperand(SizeArgOp));
-        // If the length is 0 we don't know how long it is and so we can't
-        // remove the check.
-        if (Len == 0) return false;
-        return SizeCI->getZExtValue() >= Len;
-      }
-      if (ConstantInt *Arg = dyn_cast<ConstantInt>(
-                                                  CI->getArgOperand(SizeArgOp)))
-        return SizeCI->getZExtValue() >= Arg->getZExtValue();
-    }
-    return false;
-  }
-public:
-  InstCombineFortifiedLibCalls(InstCombiner *IC) : IC(IC), NewInstruction(0) { }
-  Instruction *NewInstruction;
-};
-} // end anonymous namespace
-
 // Try to fold some different type of calls here.
 // Currently we're only working with the checking functions, memcpy_chk,
 // mempcpy_chk, memmove_chk, memset_chk, strcpy_chk, stpcpy_chk, strncpy_chk,
 // strcat_chk and strncat_chk.
-Instruction *InstCombiner::tryOptimizeCall(CallInst *CI, const TargetData *TD) {
+Instruction *InstCombiner::tryOptimizeCall(CallInst *CI, const DataLayout *TD) {
   if (CI->getCalledFunction() == 0) return 0;
 
-  InstCombineFortifiedLibCalls Simplifier(this);
-  Simplifier.fold(CI, TD, TLI);
-  return Simplifier.NewInstruction;
+  if (Value *With = Simplifier->optimizeCall(CI))
+    return ReplaceInstUsesWith(*CI, With);
+
+  return 0;
 }
 
 static IntrinsicInst *FindInitTrampolineFromAlloca(Value *TrampMem) {
@@ -877,7 +871,7 @@ static IntrinsicInst *FindInitTrampoline(Value *Callee) {
 // visitCallSite - Improvements for call and invoke instructions.
 //
 Instruction *InstCombiner::visitCallSite(CallSite CS) {
-  if (isAllocLikeFn(CS.getInstruction()))
+  if (isAllocLikeFn(CS.getInstruction(), TLI))
     return visitAllocSite(*CS.getInstruction());
 
   bool Changed = false;
@@ -961,7 +955,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
     Changed = true;
   }
 
-  // Try to optimize the call if possible, we require TargetData for most of
+  // Try to optimize the call if possible, we require DataLayout for most of
   // this.  None of these calls are seen as possibly dead so go ahead and
   // delete the instruction now.
   if (CallInst *CI = dyn_cast<CallInst>(CS.getInstruction())) {
@@ -1013,8 +1007,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
       return false;   // Cannot transform this return value.
 
     if (!CallerPAL.isEmpty() && !Caller->use_empty()) {
-      Attributes RAttrs = CallerPAL.getRetAttributes();
-      if (RAttrs & Attribute::typeIncompatible(NewRetTy))
+      Attributes::Builder RAttrs = CallerPAL.getRetAttributes();
+      if (RAttrs.hasAttributes(Attributes::typeIncompatible(NewRetTy)))
         return false;   // Attribute not compatible with transformed value.
     }
 
@@ -1044,12 +1038,13 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
       return false;   // Cannot transform this parameter value.
 
     Attributes Attrs = CallerPAL.getParamAttributes(i + 1);
-    if (Attrs & Attribute::typeIncompatible(ParamTy))
+    if (Attributes::Builder(Attrs).
+          hasAttributes(Attributes::typeIncompatible(ParamTy)))
       return false;   // Attribute not compatible with transformed value.
 
     // If the parameter is passed as a byval argument, then we have to have a
     // sized type and the sized type has to have the same size as the old type.
-    if (ParamTy != ActTy && (Attrs & Attribute::ByVal)) {
+    if (ParamTy != ActTy && Attrs.hasAttribute(Attributes::ByVal)) {
       PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy);
       if (ParamPTy == 0 || !ParamPTy->getElementType()->isSized() || TD == 0)
         return false;
@@ -1101,7 +1096,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
       if (CallerPAL.getSlot(i - 1).Index <= FT->getNumParams())
         break;
       Attributes PAttrs = CallerPAL.getSlot(i - 1).Attrs;
-      if (PAttrs & Attribute::VarArgsIncompatible)
+      if (PAttrs.hasIncompatibleWithVarArgsAttrs())
         return false;
     }
 
@@ -1114,15 +1109,17 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
   attrVec.reserve(NumCommonArgs);
 
   // Get any return attributes.
-  Attributes RAttrs = CallerPAL.getRetAttributes();
+  Attributes::Builder RAttrs = CallerPAL.getRetAttributes();
 
   // If the return value is not being used, the type may not be compatible
   // with the existing attributes.  Wipe out any problematic attributes.
-  RAttrs &= ~Attribute::typeIncompatible(NewRetTy);
+  RAttrs.removeAttributes(Attributes::typeIncompatible(NewRetTy));
 
   // Add the new return attributes.
-  if (RAttrs)
-    attrVec.push_back(AttributeWithIndex::get(0, RAttrs));
+  if (RAttrs.hasAttributes())
+    attrVec.push_back(
+      AttributeWithIndex::get(AttrListPtr::ReturnIndex,
+                              Attributes::get(FT->getContext(), RAttrs)));
 
   AI = CS.arg_begin();
   for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) {
@@ -1136,7 +1133,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
     }
 
     // Add any parameter attributes.
-    if (Attributes PAttrs = CallerPAL.getParamAttributes(i + 1))
+    Attributes PAttrs = CallerPAL.getParamAttributes(i + 1);
+    if (PAttrs.hasAttributes())
       attrVec.push_back(AttributeWithIndex::get(i + 1, PAttrs));
   }
 
@@ -1164,14 +1162,17 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
         }
 
         // Add any parameter attributes.
-        if (Attributes PAttrs = CallerPAL.getParamAttributes(i + 1))
+        Attributes PAttrs = CallerPAL.getParamAttributes(i + 1);
+        if (PAttrs.hasAttributes())
           attrVec.push_back(AttributeWithIndex::get(i + 1, PAttrs));
       }
     }
   }
 
-  if (Attributes FnAttrs =  CallerPAL.getFnAttributes())
-    attrVec.push_back(AttributeWithIndex::get(~0, FnAttrs));
+  Attributes FnAttrs = CallerPAL.getFnAttributes();
+  if (FnAttrs.hasAttributes())
+    attrVec.push_back(AttributeWithIndex::get(AttrListPtr::FunctionIndex,
+                                              FnAttrs));
 
   if (NewRetTy->isVoidTy())
     Caller->setName("");   // Void type should not have a name.
@@ -1240,8 +1241,9 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
 
   // If the call already has the 'nest' attribute somewhere then give up -
   // otherwise 'nest' would occur twice after splicing in the chain.
-  if (Attrs.hasAttrSomewhere(Attribute::Nest))
-    return 0;
+  for (unsigned I = 0, E = Attrs.getNumAttrs(); I != E; ++I)
+    if (Attrs.getAttributesAtIndex(I).hasAttribute(Attributes::Nest))
+      return 0;
 
   assert(Tramp &&
          "transformCallThroughTrampoline called with incorrect CallSite.");
@@ -1254,12 +1256,12 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
   if (!NestAttrs.isEmpty()) {
     unsigned NestIdx = 1;
     Type *NestTy = 0;
-    Attributes NestAttr = Attribute::None;
+    Attributes NestAttr;
 
     // Look for a parameter marked with the 'nest' attribute.
     for (FunctionType::param_iterator I = NestFTy->param_begin(),
          E = NestFTy->param_end(); I != E; ++NestIdx, ++I)
-      if (NestAttrs.paramHasAttr(NestIdx, Attribute::Nest)) {
+      if (NestAttrs.getParamAttributes(NestIdx).hasAttribute(Attributes::Nest)){
         // Record the parameter type and any other attributes.
         NestTy = *I;
         NestAttr = NestAttrs.getParamAttributes(NestIdx);
@@ -1278,8 +1280,10 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
       // mean appending it.  Likewise for attributes.
 
       // Add any result attributes.
-      if (Attributes Attr = Attrs.getRetAttributes())
-        NewAttrs.push_back(AttributeWithIndex::get(0, Attr));
+      Attributes Attr = Attrs.getRetAttributes();
+      if (Attr.hasAttributes())
+        NewAttrs.push_back(AttributeWithIndex::get(AttrListPtr::ReturnIndex,
+                                                   Attr));
 
       {
         unsigned Idx = 1;
@@ -1299,7 +1303,8 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
 
           // Add the original argument and attributes.
           NewArgs.push_back(*I);
-          if (Attributes Attr = Attrs.getParamAttributes(Idx))
+          Attr = Attrs.getParamAttributes(Idx);
+          if (Attr.hasAttributes())
             NewAttrs.push_back
               (AttributeWithIndex::get(Idx + (Idx >= NestIdx), Attr));
 
@@ -1308,8 +1313,10 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
       }
 
       // Add any function attributes.
-      if (Attributes Attr = Attrs.getFnAttributes())
-        NewAttrs.push_back(AttributeWithIndex::get(~0, Attr));
+      Attr = Attrs.getFnAttributes();
+      if (Attr.hasAttributes())
+        NewAttrs.push_back(AttributeWithIndex::get(AttrListPtr::FunctionIndex,
+                                                   Attr));
 
       // The trampoline may have been bitcast to a bogus type (FTy).
       // Handle this by synthesizing a new function type, equal to FTy