Fix two issues that Eli Friedman pointed out, where we would misoptimize code like:
[oota-llvm.git] / lib / Transforms / Scalar / MemCpyOptimizer.cpp
index c599928d82e0c60384514393d56c1b2fac61410a..18f5f0c1306516f23e5791c4960d37e16fe5bd91 100644 (file)
@@ -554,10 +554,17 @@ bool MemCpyOpt::performCallSlotOptzn(MemCpyInst *cpy, CallInst *C) {
     User* UI = srcUseList.back();
     srcUseList.pop_back();
 
-    if (isa<GetElementPtrInst>(UI) || isa<BitCastInst>(UI)) {
+    if (isa<BitCastInst>(UI)) {
       for (User::use_iterator I = UI->use_begin(), E = UI->use_end();
            I != E; ++I)
         srcUseList.push_back(*I);
+    } else if (GetElementPtrInst* G = dyn_cast<GetElementPtrInst>(UI)) {
+      if (G->hasAllZeroIndices())
+        for (User::use_iterator I = UI->use_begin(), E = UI->use_end();
+             I != E; ++I)
+          srcUseList.push_back(*I);
+      else
+        return false;
     } else if (UI != C && UI != cpy) {
       return false;
     }
@@ -582,12 +589,16 @@ bool MemCpyOpt::performCallSlotOptzn(MemCpyInst *cpy, CallInst *C) {
   // All the checks have passed, so do the transformation.
   bool changedArgument = false;
   for (unsigned i = 0; i < CS.arg_size(); ++i)
-    if (CS.getArgument(i) == cpySrc) {
+    if (CS.getArgument(i)->stripPointerCasts() == cpySrc) {
       if (cpySrc->getType() != cpyDest->getType())
         cpyDest = CastInst::CreatePointerCast(cpyDest, cpySrc->getType(),
                                               cpyDest->getName(), C);
       changedArgument = true;
-      CS.setArgument(i, cpyDest);
+      if (CS.getArgument(i)->getType() != cpyDest->getType())
+        CS.setArgument(i, CastInst::CreatePointerCast(cpyDest, 
+                       CS.getArgument(i)->getType(), cpyDest->getName(), C));
+      else
+        CS.setArgument(i, cpyDest);
     }
 
   if (!changedArgument)