[RewriteStatepointsForGC] Missed review comment from 234651 & build fix
[oota-llvm.git] / lib / Transforms / Scalar / RewriteStatepointsForGC.cpp
index efbfcd7231e56d1cddcfda821b3733c98a121a38..d0f2381e87bcb52460d9d4cb3b9c9651fd9d66c4 100644 (file)
@@ -131,38 +131,63 @@ static bool isGCPointerType(const Type *T) {
   return false;
 }
 
-/// Return true if the Value is a gc reference type which is potentially used
-/// after the instruction 'loc'.  This is only used with the edge reachability
-/// liveness code.  Note: It is assumed the V dominates loc.
-static bool isLiveGCReferenceAt(Value &V, Instruction *loc, DominatorTree &DT,
-                                LoopInfo *LI) {
-  if (!isGCPointerType(V.getType()))
-    return false;
-
-  if (V.use_empty())
-    return false;
-
-  // Given assumption that V dominates loc, this may be live
-  return true;
+// Return true if this type is one which a) is a gc pointer or contains a GC
+// pointer and b) is of a type this code expects to encounter as a live value.
+// (The insertion code will assert that a type which matches (a) and not (b)
+// is not encountered.) 
+static bool isHandledGCPointerType(Type *T) {
+  // We fully support gc pointers
+  if (isGCPointerType(T))
+    return true;
+  // We partially support vectors of gc pointers. The code will assert if it
+  // can't handle something.
+  if (auto VT = dyn_cast<VectorType>(T))
+    if (isGCPointerType(VT->getElementType()))
+      return true;
+  return false;
 }
 
 #ifndef NDEBUG
-static bool isAggWhichContainsGCPtrType(Type *Ty) {
+/// Returns true if this type contains a gc pointer whether we know how to
+/// handle that type or not.
+static bool containsGCPtrType(Type *Ty) {
+  if(isGCPointerType(Ty))
+    return true;
   if (VectorType *VT = dyn_cast<VectorType>(Ty))
     return isGCPointerType(VT->getScalarType());
   if (ArrayType *AT = dyn_cast<ArrayType>(Ty))
-    return isGCPointerType(AT->getElementType()) ||
-           isAggWhichContainsGCPtrType(AT->getElementType());
+    return containsGCPtrType(AT->getElementType());
   if (StructType *ST = dyn_cast<StructType>(Ty))
     return std::any_of(ST->subtypes().begin(), ST->subtypes().end(),
                        [](Type *SubType) {
-                         return isGCPointerType(SubType) ||
-                                isAggWhichContainsGCPtrType(SubType);
+                         return containsGCPtrType(SubType);
                        });
   return false;
 }
+
+// Returns true if this is a type which a) is a gc pointer or contains a GC
+// pointer and b) is of a type which the code doesn't expect (i.e. first class
+// aggregates).  Used to trip assertions.
+static bool isUnhandledGCPointerType(Type *Ty) {
+  return containsGCPtrType(Ty) && !isHandledGCPointerType(Ty);
+}
 #endif
 
+/// Return true if the Value is a gc reference type which is potentially used
+/// after the instruction 'loc'.  This is only used with the edge reachability
+/// liveness code.  Note: It is assumed the V dominates loc.
+static bool isLiveGCReferenceAt(Value &V, Instruction *Loc, DominatorTree &DT,
+                                LoopInfo *LI) {
+  if (!isHandledGCPointerType(V.getType()))
+    return false;
+
+  if (V.use_empty())
+    return false;
+
+  // Given assumption that V dominates loc, this may be live
+  return true;
+}
+
 // Conservatively identifies any definitions which might be live at the
 // given instruction. The  analysis is performed immediately before the
 // given instruction. Values defined by that instruction are not considered
@@ -189,7 +214,7 @@ static void findLiveGCValuesAtInst(Instruction *term, BasicBlock *pred,
   // Are there any gc pointer arguments live over this point?  This needs to be
   // special cased since arguments aren't defined in basic blocks.
   for (Argument &arg : F->args()) {
-    assert(!isAggWhichContainsGCPtrType(arg.getType()) &&
+    assert(!isUnhandledGCPointerType(arg.getType()) &&
            "support for FCA unimplemented");
 
     if (is_live_gc_reference(arg)) {
@@ -233,7 +258,7 @@ static void findLiveGCValuesAtInst(Instruction *term, BasicBlock *pred,
         break;
       }
 
-      assert(!isAggWhichContainsGCPtrType(inst.getType()) &&
+      assert(!isUnhandledGCPointerType(inst.getType()) &&
              "support for FCA unimplemented");
 
       if (is_live_gc_reference(inst)) {
@@ -299,10 +324,49 @@ analyzeParsePointLiveness(DominatorTree &DT, const CallSite &CS,
   result.liveset = liveset;
 }
 
-/// True iff this value is the null pointer constant (of any pointer type)
-static bool LLVM_ATTRIBUTE_UNUSED isNullConstant(Value *V) {
-  return isa<Constant>(V) && isa<PointerType>(V->getType()) &&
-         cast<Constant>(V)->isNullValue();
+/// If we can trivially determine that this vector contains only base pointers,
+/// return the base instruction.  
+static Value *findBaseOfVector(Value *I) {
+  assert(I->getType()->isVectorTy() &&
+         cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
+         "Illegal to ask for the base pointer of a non-pointer type");
+
+  // Each case parallels findBaseDefiningValue below, see that code for
+  // detailed motivation.
+
+  if (isa<Argument>(I))
+    // An incoming argument to the function is a base pointer
+    return I;
+
+  // We shouldn't see the address of a global as a vector value?
+  assert(!isa<GlobalVariable>(I) &&
+         "unexpected global variable found in base of vector");
+
+  // inlining could possibly introduce phi node that contains
+  // undef if callee has multiple returns
+  if (isa<UndefValue>(I))
+    // utterly meaningless, but useful for dealing with partially optimized
+    // code.
+    return I; 
+
+  // Due to inheritance, this must be _after_ the global variable and undef
+  // checks
+  if (Constant *Con = dyn_cast<Constant>(I)) {
+    assert(!isa<GlobalVariable>(I) && !isa<UndefValue>(I) &&
+           "order of checks wrong!");
+    assert(Con->isNullValue() && "null is the only case which makes sense");
+    return Con;
+  }
+
+  if (isa<LoadInst>(I))
+    return I;
+
+  // Note: This code is currently rather incomplete.  We are essentially only
+  // handling cases where the vector element is trivially a base pointer.  We
+  // need to update the entire base pointer construction algorithm to know how
+  // to track vector elements and potentially scalarize, but the case which
+  // would motivate the work hasn't shown up in real workloads yet.
+  llvm_unreachable("no base found for vector element");
 }
 
 /// Helper function for findBasePointer - Will return a value which either a)
@@ -312,10 +376,16 @@ static Value *findBaseDefiningValue(Value *I) {
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
-  // There are instructions which can never return gc pointer values.  Sanity
-  // check that this is actually true.
-  assert(!isa<InsertElementInst>(I) && !isa<ExtractElementInst>(I) &&
-         !isa<ShuffleVectorInst>(I) && "Vector types are not gc pointers");
+  // This case is a bit of a hack - it only handles extracts from vectors which
+  // trivially contain only base pointers.  See note inside the function for
+  // how to improve this.
+  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
+    Value *VectorOperand = EEI->getVectorOperand();
+    Value *VectorBase = findBaseOfVector(VectorOperand);
+    (void)VectorBase;
+    assert(VectorBase && "extract element not known to be a trivial base");
+    return EEI;
+  }
 
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
@@ -346,9 +416,8 @@ static Value *findBaseDefiningValue(Value *I) {
     // off a potentially null value and have proven it null.  We also use
     // null pointers in dead paths of relocation phis (which we might later
     // want to find a base pointer for).
-    assert(Con->getType()->isPointerTy() &&
-           "Base for pointer must be another pointer");
-    assert(Con->isNullValue() && "null is the only case which makes sense");
+    assert(isa<ConstantPointerNull>(Con) &&
+           "null is the only case which makes sense");
     return Con;
   }
 
@@ -432,15 +501,15 @@ static Value *findBaseDefiningValue(Value *I) {
 }
 
 /// Returns the base defining value for this value.
-static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &cache) {
-  Value *&Cached = cache[I];
+static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache) {
+  Value *&Cached = Cache[I];
   if (!Cached) {
     Cached = findBaseDefiningValue(I);
   }
-  assert(cache[I] != nullptr);
+  assert(Cache[I] != nullptr);
 
   if (TraceLSP) {
-    errs() << "fBDV-cached: " << I->getName() << " -> " << Cached->getName()
+    dbgs() << "fBDV-cached: " << I->getName() << " -> " << Cached->getName()
            << "\n";
   }
   return Cached;
@@ -448,25 +517,26 @@ static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &cache) {
 
 /// Return a base pointer for this value if known.  Otherwise, return it's
 /// base defining value.
-static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &cache) {
-  Value *def = findBaseDefiningValueCached(I, cache);
-  auto Found = cache.find(def);
-  if (Found != cache.end()) {
+static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
+  Value *Def = findBaseDefiningValueCached(I, Cache);
+  auto Found = Cache.find(Def);
+  if (Found != Cache.end()) {
     // Either a base-of relation, or a self reference.  Caller must check.
     return Found->second;
   }
   // Only a BDV available
-  return def;
+  return Def;
 }
 
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
-static bool isKnownBaseResult(Value *v) {
-  if (!isa<PHINode>(v) && !isa<SelectInst>(v)) {
+static bool isKnownBaseResult(Value *V) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V)) {
     // no recursion possible
     return true;
   }
-  if (cast<Instruction>(v)->getMetadata("is_base_value")) {
+  if (isa<Instruction>(V) &&
+      cast<Instruction>(V)->getMetadata("is_base_value")) {
     // This is a previously inserted base phi or select.  We know
     // that this is a base value.
     return true;
@@ -940,10 +1010,10 @@ static void findBasePointers(const StatepointLiveSetTy &live,
     // If you see this trip and like to live really dangerously, the code should
     // be correct, just with idioms the verifier can't handle.  You can try
     // disabling the verifier at your own substaintial risk.
-    assert(!isNullConstant(base) && "the relocation code needs adjustment to "
-                                    "handle the relocation of a null pointer "
-                                    "constant without causing false positives "
-                                    "in the safepoint ir verifier.");
+    assert(!isa<ConstantPointerNull>(base) && 
+           "the relocation code needs adjustment to handle the relocation of "
+           "a null pointer constant without causing false positives in the "
+           "safepoint ir verifier.");
   }
 }
 
@@ -1397,14 +1467,13 @@ static void relocationViaAlloca(
     Function &F, DominatorTree &DT, ArrayRef<Value *> live,
     ArrayRef<struct PartiallyConstructedSafepointRecord> records) {
 #ifndef NDEBUG
-  int initialAllocaNum = 0;
-
-  // record initial number of allocas
-  for (inst_iterator itr = inst_begin(F), end = inst_end(F); itr != end;
-       itr++) {
-    if (isa<AllocaInst>(*itr))
-      initialAllocaNum++;
-  }
+  // record initial number of (static) allocas; we'll check we have the same
+  // number when we get done.
+  int InitialAllocaNum = 0;
+  for (auto I = F.getEntryBlock().begin(), E = F.getEntryBlock().end(); 
+       I != E; I++)
+    if (isa<AllocaInst>(*I))
+      InitialAllocaNum++;
 #endif
 
   // TODO-PERF: change data structures, reserve
@@ -1546,7 +1615,7 @@ static void relocationViaAlloca(
       }
     } else {
       assert((isa<Argument>(def) || isa<GlobalVariable>(def) ||
-              (isa<Constant>(def) && cast<Constant>(def)->isNullValue())) &&
+              isa<ConstantPointerNull>(def)) &&
              "Must be argument or global");
       store->insertAfter(cast<Instruction>(alloca));
     }
@@ -1560,12 +1629,11 @@ static void relocationViaAlloca(
   }
 
 #ifndef NDEBUG
-  for (inst_iterator itr = inst_begin(F), end = inst_end(F); itr != end;
-       itr++) {
-    if (isa<AllocaInst>(*itr))
-      initialAllocaNum--;
-  }
-  assert(initialAllocaNum == 0 && "We must not introduce any extra allocas");
+  for (auto I = F.getEntryBlock().begin(), E = F.getEntryBlock().end(); 
+       I != E; I++)
+    if (isa<AllocaInst>(*I))
+      InitialAllocaNum--;
+  assert(InitialAllocaNum == 0 && "We must not introduce any extra allocas");
 #endif
 }
 
@@ -1658,6 +1726,117 @@ static void addBasesAsLiveValues(StatepointLiveSetTy &liveset,
   assert(liveset.size() == PointerToBase.size());
 }
 
+/// Remove any vector of pointers from the liveset by scalarizing them over the
+/// statepoint instruction.  Adds the scalarized pieces to the liveset.  It
+/// would be preferrable to include the vector in the statepoint itself, but
+/// the lowering code currently does not handle that.  Extending it would be
+/// slightly non-trivial since it requires a format change.  Given how rare
+/// such cases are (for the moment?) scalarizing is an acceptable comprimise.
+static void splitVectorValues(Instruction *StatepointInst,
+                              StatepointLiveSetTy& LiveSet, DominatorTree &DT) {
+  SmallVector<Value *, 16> ToSplit;
+  for (Value *V : LiveSet)
+    if (isa<VectorType>(V->getType()))
+      ToSplit.push_back(V);
+
+  if (ToSplit.empty())
+    return;
+
+  Function &F = *(StatepointInst->getParent()->getParent());
+
+  DenseMap<Value*, AllocaInst*> AllocaMap;
+  // First is normal return, second is exceptional return (invoke only)
+  DenseMap<Value*, std::pair<Value*,Value*>> Replacements;
+  for (Value *V : ToSplit) {
+    LiveSet.erase(V);
+
+    AllocaInst *Alloca = new AllocaInst(V->getType(), "",
+                                        F.getEntryBlock().getFirstNonPHI());
+    AllocaMap[V] = Alloca;
+
+    VectorType *VT = cast<VectorType>(V->getType());
+    IRBuilder<> Builder(StatepointInst);
+    SmallVector<Value*, 16> Elements;
+    for (unsigned i = 0; i < VT->getNumElements(); i++)
+      Elements.push_back(Builder.CreateExtractElement(V, Builder.getInt32(i)));
+    LiveSet.insert(Elements.begin(), Elements.end());
+
+    auto InsertVectorReform = [&](Instruction *IP) {
+      Builder.SetInsertPoint(IP);
+      Builder.SetCurrentDebugLocation(IP->getDebugLoc());
+      Value *ResultVec = UndefValue::get(VT);
+      for (unsigned i = 0; i < VT->getNumElements(); i++)
+        ResultVec = Builder.CreateInsertElement(ResultVec, Elements[i],
+                                                Builder.getInt32(i));
+      return ResultVec;
+    };
+
+    if (isa<CallInst>(StatepointInst)) {
+      BasicBlock::iterator Next(StatepointInst);
+      Next++;
+      Instruction *IP = &*(Next);
+      Replacements[V].first = InsertVectorReform(IP);
+      Replacements[V].second = nullptr;
+    } else {
+      InvokeInst *Invoke = cast<InvokeInst>(StatepointInst);
+      // We've already normalized - check that we don't have shared destination
+      // blocks 
+      BasicBlock *NormalDest = Invoke->getNormalDest();
+      assert(!isa<PHINode>(NormalDest->begin()));
+      BasicBlock *UnwindDest = Invoke->getUnwindDest();
+      assert(!isa<PHINode>(UnwindDest->begin()));
+      // Insert insert element sequences in both successors
+      Instruction *IP = &*(NormalDest->getFirstInsertionPt());
+      Replacements[V].first = InsertVectorReform(IP);
+      IP = &*(UnwindDest->getFirstInsertionPt());
+      Replacements[V].second = InsertVectorReform(IP);
+    }
+  }
+  for (Value *V : ToSplit) {
+    AllocaInst *Alloca = AllocaMap[V];
+
+    // Capture all users before we start mutating use lists
+    SmallVector<Instruction*, 16> Users;
+    for (User *U : V->users())
+      Users.push_back(cast<Instruction>(U));
+
+    for (Instruction *I : Users) {
+      if (auto Phi = dyn_cast<PHINode>(I)) {
+        for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++)
+          if (V == Phi->getIncomingValue(i)) {
+            LoadInst *Load = new LoadInst(Alloca, "",
+                                 Phi->getIncomingBlock(i)->getTerminator());
+            Phi->setIncomingValue(i, Load);
+          }
+      } else {
+        LoadInst *Load = new LoadInst(Alloca, "", I);
+        I->replaceUsesOfWith(V, Load);
+      }
+    }
+
+    // Store the original value and the replacement value into the alloca
+    StoreInst *Store = new StoreInst(V, Alloca);
+    if (auto I = dyn_cast<Instruction>(V))
+      Store->insertAfter(I);
+    else
+      Store->insertAfter(Alloca);
+    
+    // Normal return for invoke, or call return
+    Instruction *Replacement = cast<Instruction>(Replacements[V].first);
+    (new StoreInst(Replacement, Alloca))->insertAfter(Replacement);
+    // Unwind return for invoke only
+    Replacement = cast_or_null<Instruction>(Replacements[V].second);
+    if (Replacement)
+      (new StoreInst(Replacement, Alloca))->insertAfter(Replacement);
+  }
+
+  // apply mem2reg to promote alloca to SSA
+  SmallVector<AllocaInst*, 16> Allocas;
+  for (Value *V : ToSplit)
+    Allocas.push_back(AllocaMap[V]);
+  PromoteMemToReg(Allocas, DT);
+}
+
 static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
                               SmallVectorImpl<CallSite> &toUpdate) {
 #ifndef NDEBUG
@@ -1688,7 +1867,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
     SmallVector<Value *, 64> DeoptValues;
     for (Use &U : StatepointCS.vm_state_args()) {
       Value *Arg = cast<Value>(&U);
-      if (isGCPointerType(Arg->getType()))
+      assert(!isUnhandledGCPointerType(Arg->getType()) &&
+             "support for FCA unimplemented");
+      if (isHandledGCPointerType(Arg->getType()))
         DeoptValues.push_back(Arg);
     }
     insertUseHolderAfter(CS, DeoptValues, holders);
@@ -1706,6 +1887,17 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   // site.
   findLiveReferences(F, DT, P, toUpdate, records);
 
+  // Do a limited scalarization of any live at safepoint vector values which
+  // contain pointers.  This enables this pass to run after vectorization at
+  // the cost of some possible performance loss.  TODO: it would be nice to
+  // natively support vectors all the way through the backend so we don't need
+  // to scalarize here.
+  for (size_t i = 0; i < records.size(); i++) {
+    struct PartiallyConstructedSafepointRecord &info = records[i];
+    Instruction *statepoint = toUpdate[i].getInstruction();
+    splitVectorValues(cast<Instruction>(statepoint), info.liveset, DT);
+  }
+
   // B) Find the base pointers for each live pointer
   /* scope for caching */ {
     // Cache the 'defining value' relation used in the computation and
@@ -1863,18 +2055,46 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) {
   if (!shouldRewriteStatepointsIn(F))
     return false;
 
-  // Gather all the statepoints which need rewritten.
+  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  
+  // Gather all the statepoints which need rewritten.  Be careful to only
+  // consider those in reachable code since we need to ask dominance queries
+  // when rewriting.  We'll delete the unreachable ones in a moment.
   SmallVector<CallSite, 64> ParsePointNeeded;
+  bool HasUnreachableStatepoint = false;
   for (Instruction &I : inst_range(F)) {
     // TODO: only the ones with the flag set!
-    if (isStatepoint(I))
-      ParsePointNeeded.push_back(CallSite(&I));
+    if (isStatepoint(I)) {
+      if (DT.isReachableFromEntry(I.getParent()))
+        ParsePointNeeded.push_back(CallSite(&I));
+      else
+        HasUnreachableStatepoint = true;
+    }
   }
 
+  bool MadeChange = false;
+  
+  // Delete any unreachable statepoints so that we don't have unrewritten
+  // statepoints surviving this pass.  This makes testing easier and the
+  // resulting IR less confusing to human readers.  Rather than be fancy, we
+  // just reuse a utility function which removes the unreachable blocks.
+  if (HasUnreachableStatepoint)
+    MadeChange |= removeUnreachableBlocks(F);
+
   // Return early if no work to do.
   if (ParsePointNeeded.empty())
-    return false;
+    return MadeChange;
+
+  // As a prepass, go ahead and aggressively destroy single entry phi nodes.
+  // These are created by LCSSA.  They have the effect of increasing the size
+  // of liveness sets for no good reason.  It may be harder to do this post
+  // insertion since relocations and base phis can confuse things.
+  for (BasicBlock &BB : F)
+    if (BB.getUniquePredecessor()) {
+      MadeChange = true;
+      FoldSingleEntryPHINodes(&BB);
+    }
 
-  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  return insertParsePoints(F, DT, this, ParsePointNeeded);
+  MadeChange |= insertParsePoints(F, DT, this, ParsePointNeeded);
+  return MadeChange;
 }