De-constify pointers to Type since they can't be modified. NFC
[oota-llvm.git] / lib / Transforms / IPO / MergeFunctions.cpp
index 6ef1ac77da32c75ec0e0a606d50b6f2ef6af91e3..dba2d1bee601c44d5fc545cd9c9a88d2c2cda840 100644 (file)
@@ -500,9 +500,9 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     unsigned TyLWidth = 0;
     unsigned TyRWidth = 0;
 
-    if (const VectorType *VecTyL = dyn_cast<VectorType>(TyL))
+    if (auto *VecTyL = dyn_cast<VectorType>(TyL))
       TyLWidth = VecTyL->getBitWidth();
-    if (const VectorType *VecTyR = dyn_cast<VectorType>(TyR))
+    if (auto *VecTyR = dyn_cast<VectorType>(TyR))
       TyRWidth = VecTyR->getBitWidth();
 
     if (TyLWidth != TyRWidth)
@@ -1282,6 +1282,25 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
     ++UI;
     CallSite CS(U->getUser());
     if (CS && CS.isCallee(U)) {
+      // Transfer the called function's attributes to the call site. Due to the
+      // bitcast we will 'loose' ABI changing attributes because the 'called
+      // function' is no longer a Function* but the bitcast. Code that looks up
+      // the attributes from the called function will fail.
+      auto &Context = New->getContext();
+      auto NewFuncAttrs = New->getAttributes();
+      auto CallSiteAttrs = CS.getAttributes();
+
+      CallSiteAttrs = CallSiteAttrs.addAttributes(
+          Context, AttributeSet::ReturnIndex, NewFuncAttrs.getRetAttributes());
+
+      for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) {
+        AttributeSet Attrs = NewFuncAttrs.getParamAttributes(argIdx);
+        if (Attrs.getNumSlots())
+          CallSiteAttrs = CallSiteAttrs.addAttributes(Context, argIdx, Attrs);
+      }
+
+      CS.setAttributes(CallSiteAttrs);
+
       remove(CS.getInstruction()->getParent()->getParent());
       U->set(BitcastNew);
     }
@@ -1397,28 +1416,26 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
   if (F->mayBeOverridden()) {
     assert(G->mayBeOverridden());
 
-    if (HasGlobalAliases) {
-      // Make them both thunks to the same internal function.
-      Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "",
-                                     F->getParent());
-      H->copyAttributesFrom(F);
-      H->takeName(F);
-      removeUsers(F);
-      F->replaceAllUsesWith(H);
+    // Make them both thunks to the same internal function.
+    Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "",
+                                   F->getParent());
+    H->copyAttributesFrom(F);
+    H->takeName(F);
+    removeUsers(F);
+    F->replaceAllUsesWith(H);
 
-      unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment());
+    unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment());
 
+    if (HasGlobalAliases) {
       writeAlias(F, G);
       writeAlias(F, H);
-
-      F->setAlignment(MaxAlignment);
-      F->setLinkage(GlobalValue::PrivateLinkage);
     } else {
-      // We can't merge them. Instead, pick one and update all direct callers
-      // to call it and hope that we improve the instruction cache hit rate.
-      replaceDirectCallers(G, F);
+      writeThunk(F, G);
+      writeThunk(F, H);
     }
 
+    F->setAlignment(MaxAlignment);
+    F->setLinkage(GlobalValue::PrivateLinkage);
     ++NumDoubleWeak;
   } else {
     writeThunkOrAlias(F, G);
@@ -1434,9 +1451,10 @@ void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
 
   // A total order is already guaranteed otherwise because we process strong
   // functions before weak functions.
-  assert((F->mayBeOverridden() && G->mayBeOverridden()) ||
-         (!F->mayBeOverridden() && !G->mayBeOverridden()) &&
+  assert(((F->mayBeOverridden() && G->mayBeOverridden()) ||
+          (!F->mayBeOverridden() && !G->mayBeOverridden())) &&
          "Only change functions if both are strong or both are weak");
+  (void)F;
 
   IterToF->replaceBy(G);
 }
@@ -1517,6 +1535,8 @@ void MergeFunctions::remove(Function *F) {
 void MergeFunctions::removeUsers(Value *V) {
   std::vector<Value *> Worklist;
   Worklist.push_back(V);
+  SmallSet<Value*, 8> Visited;
+  Visited.insert(V);
   while (!Worklist.empty()) {
     Value *V = Worklist.back();
     Worklist.pop_back();
@@ -1527,8 +1547,10 @@ void MergeFunctions::removeUsers(Value *V) {
       } else if (isa<GlobalValue>(U)) {
         // do nothing
       } else if (Constant *C = dyn_cast<Constant>(U)) {
-        for (User *UU : C->users())
-          Worklist.push_back(UU);
+        for (User *UU : C->users()) {
+          if (!Visited.insert(UU).second)
+            Worklist.push_back(UU);
+        }
       }
     }
   }