[RewriteStatepointsForGC] Generalized vector phi/select handling for base pointers
authorPhilip Reames <listmail@philipreames.com>
Fri, 26 Jun 2015 22:47:37 +0000 (22:47 +0000)
committerPhilip Reames <listmail@philipreames.com>
Fri, 26 Jun 2015 22:47:37 +0000 (22:47 +0000)
This change extends the detection of base pointers for vector constructs to handle arbitrary phi and select nodes. The existing non-vector code already handles those, so this is basically just extending the vector special case to be less special cased. It still isn't generalized vector handling since we can't handle arbitrary vector instructions (e.g. shufflevectors), but it's a lot closer.

The general structure of the change is as follows:
 * Extend the base defining value relation over a subset of vector instructions and vector typed phi & select instructions.
 * Move scalarization from before base pointer rewriting to after base pointer rewriting. The extension of the BDV relation is sufficient to find vector base phis for vector inputs.
 * Preserve the existing special case logic for when the base of a vector element is locally obvious. This general idea could be extended to the scalar case as well.

Differential Revision: http://reviews.llvm.org/D10461#inline-84275

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@240850 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
test/Transforms/RewriteStatepointsForGC/live-vector.ll

index 21ff55c697e3cf169ae6813ea8cd50a555005791..ae2ae3af0c7a046c6f6c5d72e0f20f820a8df72e 100644 (file)
@@ -294,12 +294,17 @@ static void analyzeParsePointLiveness(
 
 static Value *findBaseDefiningValue(Value *I);
 
-/// If we can trivially determine that the index specified in the given vector
-/// is a base pointer, return it.  In cases where the entire vector is known to
-/// consist of base pointers, the entire vector will be returned.  This
-/// indicates that the relevant extractelement is a valid base pointer and
-/// should be used directly.
-static Value *findBaseOfVector(Value *I, Value *Index) {
+/// Return a base defining value for the 'Index' element of the given vector
+/// instruction 'I'.  If Index is null, returns a BDV for the entire vector
+/// 'I'.  As an optimization, this method will try to determine when the 
+/// element is known to already be a base pointer.  If this can be established,
+/// the second value in the returned pair will be true.  Note that either a
+/// vector or a pointer typed value can be returned.  For the former, the
+/// vector returned is a BDV (and possibly a base) of the entire vector 'I'.
+/// If the later, the return pointer is a BDV (or possibly a base) for the
+/// particular element in 'I'.  
+static std::pair<Value *, bool>
+findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
   assert(I->getType()->isVectorTy() &&
          cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
@@ -309,7 +314,7 @@ static Value *findBaseOfVector(Value *I, Value *Index) {
 
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
-    return I;
+    return std::make_pair(I, true);
 
   // We shouldn't see the address of a global as a vector value?
   assert(!isa<GlobalVariable>(I) &&
@@ -320,7 +325,7 @@ static Value *findBaseOfVector(Value *I, Value *Index) {
   if (isa<UndefValue>(I))
     // utterly meaningless, but useful for dealing with partially optimized
     // code.
-    return I;
+    return std::make_pair(I, true);
 
   // Due to inheritance, this must be _after_ the global variable and undef
   // checks
@@ -328,38 +333,56 @@ static Value *findBaseOfVector(Value *I, Value *Index) {
     assert(!isa<GlobalVariable>(I) && !isa<UndefValue>(I) &&
            "order of checks wrong!");
     assert(Con->isNullValue() && "null is the only case which makes sense");
-    return Con;
+    return std::make_pair(Con, true);
   }
-
+  
   if (isa<LoadInst>(I))
-    return I;
-
+    return std::make_pair(I, true);
+  
   // For an insert element, we might be able to look through it if we know
-  // something about the indexes, but if the indices are arbitrary values, we
-  // can't without much more extensive scalarization.
+  // something about the indexes.
   if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) {
-    Value *InsertIndex = IEI->getOperand(2);
-    // This index is inserting the value, look for it's base
-    if (InsertIndex == Index)
-      return findBaseDefiningValue(IEI->getOperand(1));
-    // Both constant, and can't be equal per above. This insert is definitely
-    // not relevant, look back at the rest of the vector and keep trying.
-    if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
-      return findBaseOfVector(IEI->getOperand(0), Index);
-  }
-
-  // Note: This code is currently rather incomplete.  We are essentially only
-  // handling cases where the vector element is trivially a base pointer.  We
-  // need to update the entire base pointer construction algorithm to know how
-  // to track vector elements and potentially scalarize, but the case which
-  // would motivate the work hasn't shown up in real workloads yet.
-  llvm_unreachable("no base found for vector element");
+    if (Index) {
+      Value *InsertIndex = IEI->getOperand(2);
+      // This index is inserting the value, look for its BDV
+      if (InsertIndex == Index)
+        return std::make_pair(findBaseDefiningValue(IEI->getOperand(1)), false);
+      // Both constant, and can't be equal per above. This insert is definitely
+      // not relevant, look back at the rest of the vector and keep trying.
+      if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
+        return findBaseDefiningValueOfVector(IEI->getOperand(0), Index);
+    }
+    
+    // We don't know whether this vector contains entirely base pointers or
+    // not.  To be conservatively correct, we treat it as a BDV and will
+    // duplicate code as needed to construct a parallel vector of bases.
+    return std::make_pair(IEI, false);
+  }
+
+  if (isa<ShuffleVectorInst>(I))
+    // We don't know whether this vector contains entirely base pointers or
+    // not.  To be conservatively correct, we treat it as a BDV and will
+    // duplicate code as needed to construct a parallel vector of bases.
+    // TODO: There a number of local optimizations which could be applied here
+    // for particular sufflevector patterns.
+    return std::make_pair(I, false);
+
+  // A PHI or Select is a base defining value.  The outer findBasePointer
+  // algorithm is responsible for constructing a base value for this BDV.
+  assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
+         "unknown vector instruction - no base found for vector element");
+  return std::make_pair(I, false);
 }
 
+static bool isKnownBaseResult(Value *V);
+
 /// Helper function for findBasePointer - Will return a value which either a)
 /// defines the base pointer for the input or b) blocks the simple search
 /// (i.e. a PHI or Select of two derived pointers)
 static Value *findBaseDefiningValue(Value *I) {
+  if (I->getType()->isVectorTy())
+    return findBaseDefiningValueOfVector(I).first;
+  
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
@@ -370,16 +393,39 @@ static Value *findBaseDefiningValue(Value *I) {
   if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
     Value *VectorOperand = EEI->getVectorOperand();
     Value *Index = EEI->getIndexOperand();
-    Value *VectorBase = findBaseOfVector(VectorOperand, Index);
-    // If the result returned is a vector, we know the entire vector must
-    // contain base pointers.  In that case, the extractelement is a valid base
-    // for this value.
-    if (VectorBase->getType()->isVectorTy())
-      return EEI;
-    // Otherwise, we needed to look through the vector to find the base for
-    // this particular element.
-    assert(VectorBase->getType()->isPointerTy());
-    return VectorBase;
+    std::pair<Value *, bool> pair =
+      findBaseDefiningValueOfVector(VectorOperand, Index);
+    Value *VectorBase = pair.first;
+    if (VectorBase->getType()->isPointerTy())
+      // We found a BDV for this specific element with the vector.  This is an
+      // optimization, but in practice it covers most of the useful cases
+      // created via scalarization.
+      return VectorBase;
+    else {
+      assert(VectorBase->getType()->isVectorTy());
+      if (pair.second)
+        // If the entire vector returned is known to be entirely base pointers,
+        // then the extractelement is valid base for this value.
+        return EEI;
+      else {
+        // Otherwise, we have an instruction which potentially produces a
+        // derived pointer and we need findBasePointers to clone code for us
+        // such that we can create an instruction which produces the
+        // accompanying base pointer.
+        // Note: This code is currently rather incomplete.  We don't currently
+        // support the general form of shufflevector of insertelement.
+        // Conceptually, these are just 'base defining values' of the same
+        // variety as phi or select instructions.  We need to update the
+        // findBasePointers algorithm to insert new 'base-only' versions of the
+        // original instructions. This is relative straight forward to do, but
+        // the case which would motivate the work hasn't shown up in real
+        // workloads yet.  
+        assert((isa<PHINode>(VectorBase) || isa<SelectInst>(VectorBase)) &&
+               "need to extend findBasePointers for generic vector"
+               "instruction cases");
+        return VectorBase;
+      }
+    }
   }
 
   if (isa<Argument>(I))
@@ -1712,7 +1758,9 @@ static void findLiveReferences(
 /// slightly non-trivial since it requires a format change.  Given how rare
 /// such cases are (for the moment?) scalarizing is an acceptable comprimise.
 static void splitVectorValues(Instruction *StatepointInst,
-                              StatepointLiveSetTy &LiveSet, DominatorTree &DT) {
+                              StatepointLiveSetTy &LiveSet,
+                              DenseMap<Value *, Value *>& PointerToBase,
+                              DominatorTree &DT) {
   SmallVector<Value *, 16> ToSplit;
   for (Value *V : LiveSet)
     if (isa<VectorType>(V->getType()))
@@ -1721,14 +1769,14 @@ static void splitVectorValues(Instruction *StatepointInst,
   if (ToSplit.empty())
     return;
 
+  DenseMap<Value *, SmallVector<Value *, 16>> ElementMapping;
+
   Function &F = *(StatepointInst->getParent()->getParent());
 
   DenseMap<Value *, AllocaInst *> AllocaMap;
   // First is normal return, second is exceptional return (invoke only)
   DenseMap<Value *, std::pair<Value *, Value *>> Replacements;
   for (Value *V : ToSplit) {
-    LiveSet.erase(V);
-
     AllocaInst *Alloca =
         new AllocaInst(V->getType(), "", F.getEntryBlock().getFirstNonPHI());
     AllocaMap[V] = Alloca;
@@ -1738,7 +1786,7 @@ static void splitVectorValues(Instruction *StatepointInst,
     SmallVector<Value *, 16> Elements;
     for (unsigned i = 0; i < VT->getNumElements(); i++)
       Elements.push_back(Builder.CreateExtractElement(V, Builder.getInt32(i)));
-    LiveSet.insert(Elements.begin(), Elements.end());
+    ElementMapping[V] = Elements;
 
     auto InsertVectorReform = [&](Instruction *IP) {
       Builder.SetInsertPoint(IP);
@@ -1771,6 +1819,7 @@ static void splitVectorValues(Instruction *StatepointInst,
       Replacements[V].second = InsertVectorReform(IP);
     }
   }
+
   for (Value *V : ToSplit) {
     AllocaInst *Alloca = AllocaMap[V];
 
@@ -1814,6 +1863,25 @@ static void splitVectorValues(Instruction *StatepointInst,
   for (Value *V : ToSplit)
     Allocas.push_back(AllocaMap[V]);
   PromoteMemToReg(Allocas, DT);
+
+  // Update our tracking of live pointers and base mappings to account for the
+  // changes we just made.
+  for (Value *V : ToSplit) {
+    auto &Elements = ElementMapping[V];
+
+    LiveSet.erase(V);
+    LiveSet.insert(Elements.begin(), Elements.end());
+    // We need to update the base mapping as well.
+    assert(PointerToBase.count(V));
+    Value *OldBase = PointerToBase[V];
+    auto &BaseElements = ElementMapping[OldBase];
+    PointerToBase.erase(V);
+    assert(Elements.size() == BaseElements.size());
+    for (unsigned i = 0; i < Elements.size(); i++) {
+      Value *Elem = Elements[i];
+      PointerToBase[Elem] = BaseElements[i];
+    }
+  }
 }
 
 // Helper function for the "rematerializeLiveValues". It walks use chain
@@ -2075,17 +2143,6 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   // site.
   findLiveReferences(F, DT, P, toUpdate, records);
 
-  // Do a limited scalarization of any live at safepoint vector values which
-  // contain pointers.  This enables this pass to run after vectorization at
-  // the cost of some possible performance loss.  TODO: it would be nice to
-  // natively support vectors all the way through the backend so we don't need
-  // to scalarize here.
-  for (size_t i = 0; i < records.size(); i++) {
-    struct PartiallyConstructedSafepointRecord &info = records[i];
-    Instruction *statepoint = toUpdate[i].getInstruction();
-    splitVectorValues(cast<Instruction>(statepoint), info.liveset, DT);
-  }
-
   // B) Find the base pointers for each live pointer
   /* scope for caching */ {
     // Cache the 'defining value' relation used in the computation and
@@ -2146,6 +2203,18 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   }
   holders.clear();
 
+  // Do a limited scalarization of any live at safepoint vector values which
+  // contain pointers.  This enables this pass to run after vectorization at
+  // the cost of some possible performance loss.  TODO: it would be nice to
+  // natively support vectors all the way through the backend so we don't need
+  // to scalarize here.
+  for (size_t i = 0; i < records.size(); i++) {
+    struct PartiallyConstructedSafepointRecord &info = records[i];
+    Instruction *statepoint = toUpdate[i].getInstruction();
+    splitVectorValues(cast<Instruction>(statepoint), info.liveset,
+                      info.PointerToBase, DT);
+  }
+
   // In order to reduce live set of statepoint we might choose to rematerialize
   // some values instead of relocating them. This is purelly an optimization and
   // does not influence correctness.
index 0a4456a68353a368fc37cdec50c3557277a05a41..26ad73737adc63ebf5fe38eb7b35dbd076eb1c46 100644 (file)
@@ -105,8 +105,6 @@ define <2 x i64 addrspace(1)*> @test5(i64 addrspace(1)* %p)
 ; CHECK-NEXT: bitcast
 ; CHECK-NEXT: gc.relocate
 ; CHECK-NEXT: bitcast
-; CHECK-NEXT: gc.relocate
-; CHECK-NEXT: bitcast
 ; CHECK-NEXT: insertelement
 ; CHECK-NEXT: insertelement
 ; CHECK-NEXT: ret <2 x i64 addrspace(1)*> %7
@@ -116,6 +114,48 @@ entry:
   ret <2 x i64 addrspace(1)*> %vec
 }
 
+
+; A base vector from a load
+define <2 x i64 addrspace(1)*> @test6(i1 %cnd, <2 x i64 addrspace(1)*>* %ptr) 
+    gc "statepoint-example" {
+; CHECK-LABEL: test6
+; CHECK-LABEL: merge:
+; CHECK-NEXT: = phi
+; CHECK-NEXT: = phi
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: gc.statepoint
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: ret <2 x i64 addrspace(1)*>
+entry:
+  br i1 %cnd, label %taken, label %untaken
+taken:
+  %obja = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+untaken:
+  %objb = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+
+merge:
+  %obj = phi <2 x i64 addrspace(1)*> [%obja, %taken], [%objb, %untaken]
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret <2 x i64 addrspace(1)*> %obj
+}
+
+
 declare void @do_safepoint()
 
 declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...)