[RewriteStatepointsForGC] Extend base pointer inference to handle insertelement
authorPhilip Reames <listmail@philipreames.com>
Wed, 9 Sep 2015 23:40:12 +0000 (23:40 +0000)
committerPhilip Reames <listmail@philipreames.com>
Wed, 9 Sep 2015 23:40:12 +0000 (23:40 +0000)
This change is simply enhancing the existing inference algorithm to handle insertelement instructions by conservatively inserting a new instruction to propagate the vector of associated base pointers. In the process, I'm ripping out the peephole optimizations which mostly helped cover the fact this hadn't been done.

Note that most of the newly inserted nodes will be nearly immediately removed by the post insertion optimization pass introduced in 246718. Arguably, we should be trying harder to avoid the malloc traffic here, but I'd rather get the code correct, then worry about compile time.

Unlike previous extensions of the algorithm to handle more case, I discovered the existing code was causing miscompiles in some cases. In particular, we had an implicit assumption that the peephole covered *all* insert element instructions, so if we had a value directly based on a insert element the peephole didn't cover, we proceeded as if it were a base anyways. Not good. I believe we had the same issue with shufflevector which is why I adjusted the predicate for them as well.

Differential Revision: http://reviews.llvm.org/D12583

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@247210 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
test/Transforms/RewriteStatepointsForGC/base-vector.ll

index 031f40e2caa6c0a11adc2700080893b2ce22d0d9..c70619cf27c5490522aa1e13dfabdc2a0c8cd27c 100644 (file)
@@ -328,7 +328,7 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I);
 /// If the later, the return pointer is a BDV (or possibly a base) for the
 /// particular element in 'I'.  
 static BaseDefiningValueResult
-findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
+findBaseDefiningValueOfVector(Value *I) {
   assert(I->getType()->isVectorTy() &&
          cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
@@ -362,35 +362,12 @@ findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
   
   if (isa<LoadInst>(I))
     return BaseDefiningValueResult(I, true);
-  
-  // For an insert element, we might be able to look through it if we know
-  // something about the indexes.
-  if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) {
-    if (Index) {
-      Value *InsertIndex = IEI->getOperand(2);
-      // This index is inserting the value, look for its BDV
-      if (InsertIndex == Index)
-        return findBaseDefiningValue(IEI->getOperand(1));
-      // Both constant, and can't be equal per above. This insert is definitely
-      // not relevant, look back at the rest of the vector and keep trying.
-      if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
-        return findBaseDefiningValueOfVector(IEI->getOperand(0), Index);
-    }
 
-    // If both inputs to the insertelement are known bases, then so is the
-    // insertelement itself.  NOTE: This should be handled within the generic
-    // base pointer inference code and after http://reviews.llvm.org/D12583,
-    // will be.  However, when strengthening asserts I needed to add this to
-    // keep an existing test passing which was 'working'. FIXME
-    if (findBaseDefiningValue(IEI->getOperand(0)).IsKnownBase &&
-        findBaseDefiningValue(IEI->getOperand(1)).IsKnownBase)
-      return BaseDefiningValueResult(IEI, true);
-    
+  if (isa<InsertElementInst>(I))
     // We don't know whether this vector contains entirely base pointers or
     // not.  To be conservatively correct, we treat it as a BDV and will
     // duplicate code as needed to construct a parallel vector of bases.
-    return BaseDefiningValueResult(IEI, false);
-  }
+    return BaseDefiningValueResult(I, false);
 
   if (isa<ShuffleVectorInst>(I))
     // We don't know whether this vector contains entirely base pointers or
@@ -528,27 +505,11 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) {
   // We may need to insert a parallel instruction to extract the appropriate
   // element out of the base vector corresponding to the input. Given this,
   // it's analogous to the phi and select case even though it's not a merge.
-  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
-    Value *VectorOperand = EEI->getVectorOperand();
-    Value *Index = EEI->getIndexOperand();
-    auto VecResult = findBaseDefiningValueOfVector(VectorOperand, Index);
-    Value *VectorBase = VecResult.BDV;
-    if (VectorBase->getType()->isPointerTy())
-      // We found a BDV for this specific element with the vector.  This is an
-      // optimization, but in practice it covers most of the useful cases
-      // created via scalarization. Note: The peephole optimization here is
-      // currently needed for correctness since the general algorithm doesn't
-      // yet handle insertelements.  That will change shortly.
-      return BaseDefiningValueResult(VectorBase, VecResult.IsKnownBase);
-    else {
-      assert(VectorBase->getType()->isVectorTy());
-      // Otherwise, we have an instruction which potentially produces a
-      // derived pointer and we need findBasePointers to clone code for us
-      // such that we can create an instruction which produces the
-      // accompanying base pointer.
-      return BaseDefiningValueResult(I, VecResult.IsKnownBase);
-    }
-  }
+  if (isa<ExtractElementInst>(I))
+    // Note: There a lot of obvious peephole cases here.  This are deliberately
+    // handled after the main base pointer inference algorithm to make writing
+    // test cases to exercise that code easier.
+    return BaseDefiningValueResult(I, false);
 
   // The last two cases here don't return a base pointer.  Instead, they
   // return a value which dynamically selects from among several base
@@ -587,7 +548,9 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V) &&
+      !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
+      !isa<ShuffleVectorInst>(V)) {
     // no recursion possible
     return true;
   }
@@ -755,7 +718,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
 
 #ifndef NDEBUG
   auto isExpectedBDVType = [](Value *BDV) {
-    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
+    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) ||
+           isa<ExtractElementInst>(BDV) || isa<InsertElementInst>(BDV);
   };
 #endif
 
@@ -795,10 +759,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         visitIncomingValue(Sel->getFalseValue());
       } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
         visitIncomingValue(EE->getVectorOperand());
+      } else if (auto *IE = dyn_cast<InsertElementInst>(Current)) {
+        visitIncomingValue(IE->getOperand(0)); // vector operand
+        visitIncomingValue(IE->getOperand(1)); // scalar operand
       } else {
-        // There are two classes of instructions we know we don't handle.
-        assert(isa<ShuffleVectorInst>(Current) ||
-               isa<InsertElementInst>(Current));
+        // There is one known class of instructions we know we don't handle.
+        assert(isa<ShuffleVectorInst>(Current));
         llvm_unreachable("unimplemented instruction case");
       }
     }
@@ -849,11 +815,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       } else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
         for (Value *Val : Phi->incoming_values())
           calculateMeet.meetWith(getStateForInput(Val));
-      } else {
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(v)) {
         // The 'meet' for an extractelement is slightly trivial, but it's still
         // useful in that it drives us to conflict if our input is.
-        auto *EE = cast<ExtractElementInst>(v);
         calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
+      } else {
+        // Given there's a inherent type mismatch between the operands, will
+        // *always* produce Conflict.
+        auto *IE = cast<InsertElementInst>(v);
+        calculateMeet.meetWith(getStateForInput(IE->getOperand(0)));
+        calculateMeet.meetWith(getStateForInput(IE->getOperand(1)));
       }
 
       BDVState oldState = states[v];
@@ -899,6 +870,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
       states[I] = BDVState(BDVState::Base, BaseInst);
     }
+
+    // Since we're joining a vector and scalar base, they can never be the
+    // same.  As a result, we should always see insert element having reached
+    // the conflict state.
+    if (isa<InsertElementInst>(I)) {
+      assert(State.isConflict());
+    }
     
     if (!State.isConflict())
       continue;
@@ -920,14 +898,22 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
           (I->getName() + ".base").str() : "base_select";
         return SelectInst::Create(Sel->getCondition(), Undef,
                                   Undef, Name, Sel);
-      } else {
-        auto *EE = cast<ExtractElementInst>(I);
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
         UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
         std::string Name = I->hasName() ?
           (I->getName() + ".base").str() : "base_ee";
         return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
                                           EE);
+      } else {
+        auto *IE = cast<InsertElementInst>(I);
+        UndefValue *VecUndef = UndefValue::get(IE->getOperand(0)->getType());
+        UndefValue *ScalarUndef = UndefValue::get(IE->getOperand(1)->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_ie";
+        return InsertElementInst::Create(VecUndef, ScalarUndef,
+                                         IE->getOperand(2), Name, IE);
       }
+
     };
     Instruction *BaseInst = MakeBaseInstPlaceholder(I);
     // Add metadata marking this as a base value
@@ -1029,14 +1015,31 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         Value *Base = getBaseForInput(InVal, BaseSel);
         BaseSel->setOperand(i, Base);
       }
-    } else {
-      auto *BaseEE = cast<ExtractElementInst>(state.getBase());
+    } else if (auto *BaseEE = dyn_cast<ExtractElementInst>(state.getBase())) {
       Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
       // Find the instruction which produces the base for each input.  We may
       // need to insert a bitcast.
       Value *Base = getBaseForInput(InVal, BaseEE);
       BaseEE->setOperand(0, Base);
+    } else {
+      auto *BaseIE = cast<InsertElementInst>(state.getBase());
+      auto *BdvIE = cast<InsertElementInst>(v);
+      auto UpdateOperand = [&](int OperandIdx) {
+        Value *InVal = BdvIE->getOperand(OperandIdx);
+        Value *Base = findBaseOrBDV(InVal, cache);
+        if (!isKnownBaseResult(Base)) {
+          // Either conflict or base.
+          assert(states.count(Base));
+          Base = states[Base].getBase();
+          assert(Base != nullptr && "unknown BDVState!");
+        }
+        assert(Base && "can't be null");
+        BaseIE->setOperand(OperandIdx, Base);
+      };
+      UpdateOperand(0); // vector operand
+      UpdateOperand(1); // scalar operand
     }
+
   }
 
   // Now that we're done with the algorithm, see if we can optimize the 
index 95011e86413282ab9c7a9b8cad3965dee0c4af11..752e495106182771a3505d5d1fd55c238adde179 100644 (file)
@@ -59,7 +59,7 @@ entry:
 ; CHECK: extractelement
 ; CHECK: statepoint
 ; CHECK: gc.relocate
-; CHECK-DAG: ; (%ptr, %obj)
+; CHECK-DAG: (%obj, %obj)
    %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
   ret i64 addrspace(1)* %obj
 }
@@ -80,6 +80,88 @@ entry:
   ret i64 addrspace(1)* %obj
 }
 
+declare void @use(i64 addrspace(1)*)
+
+; When we can optimize an extractelement from a known
+; index and avoid introducing new base pointer instructions
+define void @test5(i1 %cnd, i64 addrspace(1)* %obj)
+    gc "statepoint-example" {
+; CHECK-LABEL: @test5
+; CHECK: gc.relocate
+; CHECK-DAG: (%obj, %bdv)
+entry:
+  %gep = getelementptr i64, i64 addrspace(1)* %obj, i64 1
+  %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
+  %bdv = extractelement <2 x i64 addrspace(1)*> %vec, i32 0
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 5, i32 0, i32 -1, i32 0, i32 0, i32 0)
+  call void @use(i64 addrspace(1)* %bdv)
+  ret void
+}
+
+; When we fundementally have to duplicate
+define void @test6(i1 %cnd, i64 addrspace(1)* %obj, i64 %idx)
+    gc "statepoint-example" {
+; CHECK-LABEL: @test6
+; CHECK: %gep = getelementptr i64, i64 addrspace(1)* %obj, i64 1
+; CHECK: %vec.base = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %obj, i32 0, !is_base_value !0
+; CHECK: %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
+; CHECK: %bdv.base = extractelement <2 x i64 addrspace(1)*> %vec.base, i64 %idx, !is_base_value !0
+; CHECK:  %bdv = extractelement <2 x i64 addrspace(1)*> %vec, i64 %idx
+; CHECK: gc.statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: (%bdv.base, %bdv)
+entry:
+  %gep = getelementptr i64, i64 addrspace(1)* %obj, i64 1
+  %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
+  %bdv = extractelement <2 x i64 addrspace(1)*> %vec, i64 %idx
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 5, i32 0, i32 -1, i32 0, i32 0, i32 0)
+  call void @use(i64 addrspace(1)* %bdv)
+  ret void
+}
+
+; A more complicated example involving vector and scalar bases.
+; This is derived from a failing test case when we didn't have correct
+; insertelement handling.
+define i64 addrspace(1)* @test7(i1 %cnd, i64 addrspace(1)* %obj, 
+                                i64 addrspace(1)* %obj2)
+    gc "statepoint-example" {
+; CHECK-LABEL: @test7
+entry:
+  %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %obj2, i32 0
+  br label %merge1
+merge1:
+; CHECK-LABEL: merge1:
+; CHECK: vec2.base
+; CHECK: vec2
+; CHECK: gep
+; CHECK: vec3.base
+; CHECK: vec3
+  %vec2 = phi <2 x i64 addrspace(1)*> [ %vec, %entry ], [ %vec3, %merge1 ]
+  %gep = getelementptr i64, i64 addrspace(1)* %obj2, i64 1
+  %vec3 = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %gep, i32 0
+  br i1 %cnd, label %merge1, label %next1
+next1:
+; CHECK-LABEL: next1:
+; CHECK: bdv.base = 
+; CHECK: bdv = 
+  %bdv = extractelement <2 x i64 addrspace(1)*> %vec2, i32 0
+  br label %merge
+merge:
+; CHECK-LABEL: merge:
+; CHECK: %objb.base
+; CHECK: %objb
+; CHECK: gc.statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: (%objb.base, %objb)
+
+  %objb = phi i64 addrspace(1)* [ %obj, %next1 ], [ %bdv, %merge ]
+  br i1 %cnd, label %merge, label %next
+next:
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 5, i32 0, i32 -1, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %objb
+}
+
+
 declare void @do_safepoint()
 
 declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...)