When performing return slot optimization, remember to inform memdep when we're removi...
[oota-llvm.git] / lib / Transforms / Scalar / GVN.cpp
index 26f7cb1630a36aafbc52a7f85327c646285f5b1b..481956f6b4ce039dbb5307cc558372ce9e4c5d06 100644 (file)
@@ -1083,6 +1083,8 @@ static bool isReturnSlotOptznProfitable(Value* dest, MemCpyInst* cpy) {
 /// rather than using memcpy
 bool GVN::performReturnSlotOptzn(MemCpyInst* cpy, CallInst* C,
                                  SmallVector<Instruction*, 4>& toErase) {
+  // Deliberately get the source and destination with bitcasts stripped away,
+  // because we'll need to do type comparisons based on the underlying type.
   Value* cpyDest = cpy->getDest();
   Value* cpySrc = cpy->getSource();
   CallSite CS = CallSite::get(C);
@@ -1097,23 +1099,25 @@ bool GVN::performReturnSlotOptzn(MemCpyInst* cpy, CallInst* C,
       !CS.paramHasAttr(1, ParamAttr::NoAlias | ParamAttr::StructRet))
     return false;
   
-  // We only perform the transformation if it will be profitable. 
-  if (!isReturnSlotOptznProfitable(cpyDest, cpy))
-    return false;
-  
   // Check that something sneaky is not happening involving casting
   // return slot types around.
   if (CS.getArgument(0)->getType() != cpyDest->getType())
     return false;
+  // sret --> pointer
+  const PointerType* PT = cast<PointerType>(cpyDest->getType()); 
   
   // We can only perform the transformation if the size of the memcpy
   // is constant and equal to the size of the structure.
-  if (!isa<ConstantInt>(cpy->getLength()))
+  ConstantInt* cpyLength = dyn_cast<ConstantInt>(cpy->getLength());
+  if (!cpyLength)
     return false;
   
-  ConstantInt* cpyLength = cast<ConstantInt>(cpy->getLength());
   TargetData& TD = getAnalysis<TargetData>();
-  if (TD.getTypeStoreSize(cpyDest->getType()) == cpyLength->getZExtValue())
+  if (TD.getTypeStoreSize(PT->getElementType()) != cpyLength->getZExtValue())
+    return false;
+  
+  // We only perform the transformation if it will be profitable. 
+  if (!isReturnSlotOptznProfitable(cpyDest, cpy))
     return false;
   
   // In addition to knowing that the call does not access the return slot
@@ -1135,6 +1139,7 @@ bool GVN::performReturnSlotOptzn(MemCpyInst* cpy, CallInst* C,
   MD.dropInstruction(C);
   
   // Remove the memcpy
+  MD.removeInstruction(cpy);
   toErase.push_back(cpy);
   
   return true;
@@ -1220,13 +1225,11 @@ bool GVN::processInstruction(Instruction* I,
     if (dep == MemoryDependenceAnalysis::None ||
         dep == MemoryDependenceAnalysis::NonLocal)
       return false;
-    else if (CallInst* C = dyn_cast<CallInst>(dep)) {
-      if (!isa<MemCpyInst>(C))
-        return performReturnSlotOptzn(M, C, toErase);
-    } else if (!isa<MemCpyInst>(dep))
-      return false;
-    
-    return processMemCpy(M, cast<MemCpyInst>(dep), toErase);
+    if (MemCpyInst *MemCpy = dyn_cast<MemCpyInst>(dep))
+      return processMemCpy(M, MemCpy, toErase);
+    if (CallInst* C = dyn_cast<CallInst>(dep))
+      return performReturnSlotOptzn(M, C, toErase);
+    return false;
   }
   
   unsigned num = VN.lookup_or_add(I);