[RewriteStatepointsForGC] Extract common code, comment, and fix a build warning ...
authorPhilip Reames <listmail@philipreames.com>
Thu, 3 Sep 2015 21:57:40 +0000 (21:57 +0000)
committerPhilip Reames <listmail@philipreames.com>
Thu, 3 Sep 2015 21:57:40 +0000 (21:57 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@246810 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/RewriteStatepointsForGC.cpp

index 0d793233f45304f14e4640a1a4a2010aca59ea46..b8b739644856b1c2c28692cdf7ec68bffd6077ed 100644 (file)
@@ -435,7 +435,7 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) {
 
   // Due to inheritance, this must be _after_ the global variable and undef
   // checks
-  if (Constant *Con = dyn_cast<Constant>(I)) {
+  if (isa<Constant>(I)) {
     assert(!isa<GlobalVariable>(I) && !isa<UndefValue>(I) &&
            "order of checks wrong!");
     // Note: Finding a constant base for something marked for relocation
@@ -446,7 +446,7 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) {
     // off a potentially null value and have proven it null.  We also use
     // null pointers in dead paths of relocation phis (which we might later
     // want to find a base pointer for).
-    assert(isa<ConstantPointerNull>(Con) &&
+    assert(isa<ConstantPointerNull>(I) &&
            "null is the only case which makes sense");
     return BaseDefiningValueResult(I, true);
   }
@@ -946,6 +946,34 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
     states[I] = BDVState(BDVState::Conflict, BaseInst);
   }
 
+  // Returns a instruction which produces the base pointer for a given
+  // instruction.  The instruction is assumed to be an input to one of the BDVs
+  // seen in the inference algorithm above.  As such, we must either already
+  // know it's base defining value is a base, or have inserted a new
+  // instruction to propagate the base of it's BDV and have entered that newly
+  // introduced instruction into the state table.  In either case, we are
+  // assured to be able to determine an instruction which produces it's base
+  // pointer. 
+  auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) {
+    Value *BDV = findBaseOrBDV(Input, cache);
+    Value *Base = nullptr;
+    if (isKnownBaseResult(BDV)) {
+      Base = BDV;
+    } else {
+      // Either conflict or base.
+      assert(states.count(BDV));
+      Base = states[BDV].getBase();
+    }
+    assert(Base && "can't be null");
+    // The cast is needed since base traversal may strip away bitcasts
+    if (Base->getType() != Input->getType() &&
+        InsertPt) {
+      Base = new BitCastInst(Base, Input->getType(), "cast",
+                             InsertPt);
+    }
+    return Base;
+  };
+
   // Fixup all the inputs of the new PHIs
   for (auto Pair : states) {
     Instruction *v = cast<Instruction>(Pair.first);
@@ -976,15 +1004,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         if (blockIndex != -1) {
           Value *oldBase = basephi->getIncomingValue(blockIndex);
           basephi->addIncoming(oldBase, InBB);
+          
 #ifndef NDEBUG
-          Value *base = findBaseOrBDV(InVal, cache);
-          if (!isKnownBaseResult(base)) {
-            // Either conflict or base.
-            assert(states.count(base));
-            base = states[base].getBase();
-            assert(base != nullptr && "unknown BDVState!");
-          }
-
+          Value *Base = getBaseForInput(InVal, nullptr);
           // In essence this assert states: the only way two
           // values incoming from the same basic block may be
           // different is by being different bitcasts of the same
@@ -992,65 +1014,36 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
           // findBaseOrBDV to return an llvm::Value of the correct
           // type (and still remain pure).  This will remove the
           // need to add bitcasts.
-          assert(base->stripPointerCasts() == oldBase->stripPointerCasts() &&
+          assert(Base->stripPointerCasts() == oldBase->stripPointerCasts() &&
                  "sanity -- findBaseOrBDV should be pure!");
 #endif
           continue;
         }
 
-        // Find either the defining value for the PHI or the normal base for
-        // a non-phi node
-        Value *base = findBaseOrBDV(InVal, cache);
-        if (!isKnownBaseResult(base)) {
-          // Either conflict or base.
-          assert(states.count(base));
-          base = states[base].getBase();
-          assert(base != nullptr && "unknown BDVState!");
-        }
-        assert(base && "can't be null");
-        // Must use original input BB since base may not be Instruction
-        // The cast is needed since base traversal may strip away bitcasts
-        if (base->getType() != basephi->getType()) {
-          base = new BitCastInst(base, basephi->getType(), "cast",
-                                 InBB->getTerminator());
-        }
-        basephi->addIncoming(base, InBB);
+        // Find the instruction which produces the base for each input.  We may
+        // need to insert a bitcast in the incoming block.
+        // TODO: Need to split critical edges if insertion is needed
+        Value *Base = getBaseForInput(InVal, InBB->getTerminator());
+        basephi->addIncoming(Base, InBB);
       }
       assert(basephi->getNumIncomingValues() == NumPHIValues);
-    } else if (SelectInst *basesel = dyn_cast<SelectInst>(state.getBase())) {
-      SelectInst *sel = cast<SelectInst>(v);
+    } else if (SelectInst *BaseSel = dyn_cast<SelectInst>(state.getBase())) {
+      SelectInst *Sel = cast<SelectInst>(v);
       // Operand 1 & 2 are true, false path respectively. TODO: refactor to
       // something more safe and less hacky.
       for (int i = 1; i <= 2; i++) {
-        Value *InVal = sel->getOperand(i);
-        // Find either the defining value for the PHI or the normal base for
-        // a non-phi node
-        Value *base = findBaseOrBDV(InVal, cache);
-        if (!isKnownBaseResult(base)) {
-          // Either conflict or base.
-          assert(states.count(base));
-          base = states[base].getBase();
-          assert(base != nullptr && "unknown BDVState!");
-        }
-        assert(base && "can't be null");
-        // Must use original input BB since base may not be Instruction
-        // The cast is needed since base traversal may strip away bitcasts
-        if (base->getType() != basesel->getType()) {
-          base = new BitCastInst(base, basesel->getType(), "cast", basesel);
-        }
-        basesel->setOperand(i, base);
+        Value *InVal = Sel->getOperand(i);
+        // Find the instruction which produces the base for each input.  We may
+        // need to insert a bitcast.
+        Value *Base = getBaseForInput(InVal, BaseSel);
+        BaseSel->setOperand(i, Base);
       }
     } else {
       auto *BaseEE = cast<ExtractElementInst>(state.getBase());
       Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
-      Value *Base = findBaseOrBDV(InVal, cache);
-      if (!isKnownBaseResult(Base)) {
-        // Either conflict or base.
-        assert(states.count(Base));
-        Base = states[Base].getBase();
-        assert(Base != nullptr && "unknown BDVState!");
-      }
-      assert(Base && "can't be null");
+      // Find the instruction which produces the base for each input.  We may
+      // need to insert a bitcast.
+      Value *Base = getBaseForInput(InVal, BaseEE);
       BaseEE->setOperand(0, Base);
     }
   }