[RewriteStatepointsForGC] Reduce the number of new instructions for base pointers
[oota-llvm.git] / lib / Transforms / Scalar / RewriteStatepointsForGC.cpp
index 505070552c188e10d6eb444a1e1fbe145863dc3d..2bd337814b4e44adc68b3b12e4cb11aa2e199555 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "llvm/Pass.h"
 #include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/Statistic.h"
@@ -30,6 +31,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/Verifier.h"
@@ -74,13 +76,27 @@ static cl::opt<bool, true> ClobberNonLiveOverride("rs4gc-clobber-non-live",
                                                   cl::Hidden);
 
 namespace {
-struct RewriteStatepointsForGC : public FunctionPass {
+struct RewriteStatepointsForGC : public ModulePass {
   static char ID; // Pass identification, replacement for typeid
 
-  RewriteStatepointsForGC() : FunctionPass(ID) {
+  RewriteStatepointsForGC() : ModulePass(ID) {
     initializeRewriteStatepointsForGCPass(*PassRegistry::getPassRegistry());
   }
-  bool runOnFunction(Function &F) override;
+  bool runOnFunction(Function &F);
+  bool runOnModule(Module &M) override {
+    bool Changed = false;
+    for (Function &F : M)
+      Changed |= runOnFunction(F);
+
+    if (Changed) {
+      // stripDereferenceabilityInfo asserts that shouldRewriteStatepointsIn
+      // returns true for at least one function in the module.  Since at least
+      // one function changed, we know that the precondition is satisfied.
+      stripDereferenceabilityInfo(M);
+    }
+
+    return Changed;
+  }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     // We add and rewrite a bunch of instructions, but don't really do much
@@ -88,12 +104,26 @@ struct RewriteStatepointsForGC : public FunctionPass {
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
   }
+
+  /// The IR fed into RewriteStatepointsForGC may have had attributes implying
+  /// dereferenceability that are no longer valid/correct after
+  /// RewriteStatepointsForGC has run.  This is because semantically, after
+  /// RewriteStatepointsForGC runs, all calls to gc.statepoint "free" the entire
+  /// heap.  stripDereferenceabilityInfo (conservatively) restores correctness
+  /// by erasing all attributes in the module that externally imply
+  /// dereferenceability.
+  ///
+  void stripDereferenceabilityInfo(Module &M);
+
+  // Helpers for stripDereferenceabilityInfo
+  void stripDereferenceabilityInfoFromBody(Function &F);
+  void stripDereferenceabilityInfoFromPrototype(Function &F);
 };
 } // namespace
 
 char RewriteStatepointsForGC::ID = 0;
 
-FunctionPass *llvm::createRewriteStatepointsForGCPass() {
+ModulePass *llvm::createRewriteStatepointsForGCPass() {
   return new RewriteStatepointsForGC();
 }
 
@@ -135,7 +165,7 @@ typedef DenseSet<llvm::Value *> StatepointLiveSetTy;
 typedef DenseMap<Instruction *, Value *> RematerializedValueMapTy;
 
 struct PartiallyConstructedSafepointRecord {
-  /// The set of values known to be live accross this safepoint
+  /// The set of values known to be live across this safepoint
   StatepointLiveSetTy liveset;
 
   /// Mapping from live pointers to a base-defining-value
@@ -168,8 +198,8 @@ static void findLiveSetAtInst(Instruction *inst, GCPtrLivenessData &Data,
 // TODO: Once we can get to the GCStrategy, this becomes
 // Optional<bool> isGCManagedPointer(const Value *V) const override {
 
-static bool isGCPointerType(const Type *T) {
-  if (const PointerType *PT = dyn_cast<PointerType>(T))
+static bool isGCPointerType(Type *T) {
+  if (auto *PT = dyn_cast<PointerType>(T))
     // For the sake of this example GC, we arbitrarily pick addrspace(1) as our
     // GC managed heap.  We know that a pointer into this heap needs to be
     // updated and that no other pointer does.
@@ -245,7 +275,7 @@ static void analyzeParsePointLiveness(
 
   if (PrintLiveSet) {
     // Note: This output is used by several of the test cases
-    // The order of elemtns in a set is not stable, put them in a vec and sort
+    // The order of elements in a set is not stable, put them in a vec and sort
     // by name
     SmallVector<Value *, 64> temp;
     temp.insert(temp.end(), liveset.begin(), liveset.end());
@@ -265,12 +295,17 @@ static void analyzeParsePointLiveness(
 
 static Value *findBaseDefiningValue(Value *I);
 
-/// If we can trivially determine that the index specified in the given vector
-/// is a base pointer, return it.  In cases where the entire vector is known to
-/// consist of base pointers, the entire vector will be returned.  This
-/// indicates that the relevant extractelement is a valid base pointer and
-/// should be used directly.
-static Value *findBaseOfVector(Value *I, Value *Index) {
+/// Return a base defining value for the 'Index' element of the given vector
+/// instruction 'I'.  If Index is null, returns a BDV for the entire vector
+/// 'I'.  As an optimization, this method will try to determine when the 
+/// element is known to already be a base pointer.  If this can be established,
+/// the second value in the returned pair will be true.  Note that either a
+/// vector or a pointer typed value can be returned.  For the former, the
+/// vector returned is a BDV (and possibly a base) of the entire vector 'I'.
+/// If the later, the return pointer is a BDV (or possibly a base) for the
+/// particular element in 'I'.  
+static std::pair<Value *, bool>
+findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
   assert(I->getType()->isVectorTy() &&
          cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
@@ -280,7 +315,7 @@ static Value *findBaseOfVector(Value *I, Value *Index) {
 
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
-    return I;
+    return std::make_pair(I, true);
 
   // We shouldn't see the address of a global as a vector value?
   assert(!isa<GlobalVariable>(I) &&
@@ -291,7 +326,7 @@ static Value *findBaseOfVector(Value *I, Value *Index) {
   if (isa<UndefValue>(I))
     // utterly meaningless, but useful for dealing with partially optimized
     // code.
-    return I;
+    return std::make_pair(I, true);
 
   // Due to inheritance, this must be _after_ the global variable and undef
   // checks
@@ -299,60 +334,60 @@ static Value *findBaseOfVector(Value *I, Value *Index) {
     assert(!isa<GlobalVariable>(I) && !isa<UndefValue>(I) &&
            "order of checks wrong!");
     assert(Con->isNullValue() && "null is the only case which makes sense");
-    return Con;
+    return std::make_pair(Con, true);
   }
-
+  
   if (isa<LoadInst>(I))
-    return I;
-
+    return std::make_pair(I, true);
+  
   // For an insert element, we might be able to look through it if we know
-  // something about the indexes, but if the indices are arbitrary values, we
-  // can't without much more extensive scalarization. 
+  // something about the indexes.
   if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) {
-    Value *InsertIndex = IEI->getOperand(2);
-    // This index is inserting the value, look for it's base
-    if (InsertIndex == Index)
-      return findBaseDefiningValue(IEI->getOperand(1));
-    // Both constant, and can't be equal per above. This insert is definitely
-    // not relevant, look back at the rest of the vector and keep trying.  
-    if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
-      return findBaseOfVector(IEI->getOperand(0), Index);
-  }
-  
-  // Note: This code is currently rather incomplete.  We are essentially only
-  // handling cases where the vector element is trivially a base pointer.  We
-  // need to update the entire base pointer construction algorithm to know how
-  // to track vector elements and potentially scalarize, but the case which
-  // would motivate the work hasn't shown up in real workloads yet.
-  llvm_unreachable("no base found for vector element");
+    if (Index) {
+      Value *InsertIndex = IEI->getOperand(2);
+      // This index is inserting the value, look for its BDV
+      if (InsertIndex == Index)
+        return std::make_pair(findBaseDefiningValue(IEI->getOperand(1)), false);
+      // Both constant, and can't be equal per above. This insert is definitely
+      // not relevant, look back at the rest of the vector and keep trying.
+      if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
+        return findBaseDefiningValueOfVector(IEI->getOperand(0), Index);
+    }
+    
+    // We don't know whether this vector contains entirely base pointers or
+    // not.  To be conservatively correct, we treat it as a BDV and will
+    // duplicate code as needed to construct a parallel vector of bases.
+    return std::make_pair(IEI, false);
+  }
+
+  if (isa<ShuffleVectorInst>(I))
+    // We don't know whether this vector contains entirely base pointers or
+    // not.  To be conservatively correct, we treat it as a BDV and will
+    // duplicate code as needed to construct a parallel vector of bases.
+    // TODO: There a number of local optimizations which could be applied here
+    // for particular sufflevector patterns.
+    return std::make_pair(I, false);
+
+  // A PHI or Select is a base defining value.  The outer findBasePointer
+  // algorithm is responsible for constructing a base value for this BDV.
+  assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
+         "unknown vector instruction - no base found for vector element");
+  return std::make_pair(I, false);
 }
 
+static bool isKnownBaseResult(Value *V);
+
 /// Helper function for findBasePointer - Will return a value which either a)
-/// defines the base pointer for the input or b) blocks the simple search
-/// (i.e. a PHI or Select of two derived pointers)
+/// defines the base pointer for the input, b) blocks the simple search
+/// (i.e. a PHI or Select of two derived pointers), or c) involves a change
+/// from pointer to vector type or back.
 static Value *findBaseDefiningValue(Value *I) {
+  if (I->getType()->isVectorTy())
+    return findBaseDefiningValueOfVector(I).first;
+  
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
-  // This case is a bit of a hack - it only handles extracts from vectors which
-  // trivially contain only base pointers or cases where we can directly match
-  // the index of the original extract element to an insertion into the vector.
-  // See note inside the function for how to improve this.
-  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
-    Value *VectorOperand = EEI->getVectorOperand();
-    Value *Index = EEI->getIndexOperand();
-    Value *VectorBase = findBaseOfVector(VectorOperand, Index);
-    // If the result returned is a vector, we know the entire vector must
-    // contain base pointers.  In that case, the extractelement is a valid base
-    // for this value.
-    if (VectorBase->getType()->isVectorTy())
-      return EEI;
-    // Otherwise, we needed to look through the vector to find the base for
-    // this particular element.
-    assert(VectorBase->getType()->isPointerTy());
-    return VectorBase;
-  }
-
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
     // We should have never reached here if this argument isn't an gc value
@@ -434,7 +469,7 @@ static Value *findBaseDefiningValue(Value *I) {
     return I;
 
   // I have absolutely no idea how to implement this part yet.  It's not
-  // neccessarily hard, I just haven't really looked at it yet.
+  // necessarily hard, I just haven't really looked at it yet.
   assert(!isa<LandingPadInst>(I) && "Landing Pad is unimplemented");
 
   if (isa<AtomicCmpXchgInst>(I))
@@ -457,8 +492,35 @@ static Value *findBaseDefiningValue(Value *I) {
   assert(!isa<InsertValueInst>(I) &&
          "Base pointer for a struct is meaningless");
 
+  // An extractelement produces a base result exactly when it's input does.
+  // We may need to insert a parallel instruction to extract the appropriate
+  // element out of the base vector corresponding to the input. Given this,
+  // it's analogous to the phi and select case even though it's not a merge.
+  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
+    Value *VectorOperand = EEI->getVectorOperand();
+    Value *Index = EEI->getIndexOperand();
+    std::pair<Value *, bool> pair =
+      findBaseDefiningValueOfVector(VectorOperand, Index);
+    Value *VectorBase = pair.first;
+    if (VectorBase->getType()->isPointerTy())
+      // We found a BDV for this specific element with the vector.  This is an
+      // optimization, but in practice it covers most of the useful cases
+      // created via scalarization. Note: The peephole optimization here is
+      // currently needed for correctness since the general algorithm doesn't
+      // yet handle insertelements.  That will change shortly.
+      return VectorBase;
+    else {
+      assert(VectorBase->getType()->isVectorTy());
+      // Otherwise, we have an instruction which potentially produces a
+      // derived pointer and we need findBasePointers to clone code for us
+      // such that we can create an instruction which produces the
+      // accompanying base pointer.
+      return EEI;
+    }
+  }
+
   // The last two cases here don't return a base pointer.  Instead, they
-  // return a value which dynamically selects from amoung several base
+  // return a value which dynamically selects from among several base
   // derived pointers (each with it's own base potentially).  It's the job of
   // the caller to resolve these.
   assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
@@ -471,13 +533,10 @@ static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache) {
   Value *&Cached = Cache[I];
   if (!Cached) {
     Cached = findBaseDefiningValue(I);
+    DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> "
+                 << Cached->getName() << "\n");
   }
   assert(Cache[I] != nullptr);
-
-  if (TraceLSP) {
-    dbgs() << "fBDV-cached: " << I->getName() << " -> " << Cached->getName()
-           << "\n";
-  }
   return Cached;
 }
 
@@ -497,7 +556,7 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V)) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
     // no recursion possible
     return true;
   }
@@ -512,17 +571,19 @@ static bool isKnownBaseResult(Value *V) {
   return false;
 }
 
-// TODO: find a better name for this
 namespace {
-class PhiState {
+/// Models the state of a single base defining value in the findBasePointer
+/// algorithm for determining where a new instruction is needed to propagate
+/// the base of this BDV.
+class BDVState {
 public:
   enum Status { Unknown, Base, Conflict };
 
-  PhiState(Status s, Value *b = nullptr) : status(s), base(b) {
+  BDVState(Status s, Value *b = nullptr) : status(s), base(b) {
     assert(status != Base || b);
   }
-  PhiState(Value *b) : status(Base), base(b) {}
-  PhiState() : status(Unknown), base(nullptr) {}
+  explicit BDVState(Value *b) : status(Base), base(b) {}
+  BDVState() : status(Unknown), base(nullptr) {}
 
   Status getStatus() const { return status; }
   Value *getBase() const { return base; }
@@ -531,15 +592,18 @@ public:
   bool isUnknown() const { return getStatus() == Unknown; }
   bool isConflict() const { return getStatus() == Conflict; }
 
-  bool operator==(const PhiState &other) const {
+  bool operator==(const BDVState &other) const {
     return base == other.base && status == other.status;
   }
 
-  bool operator!=(const PhiState &other) const { return !(*this == other); }
+  bool operator!=(const BDVState &other) const { return !(*this == other); }
 
-  void dump() {
-    errs() << status << " (" << base << " - "
-           << (base ? base->getName() : "nullptr") << "): ";
+  LLVM_DUMP_METHOD
+  void dump() const { print(dbgs()); dbgs() << '\n'; }
+  
+  void print(raw_ostream &OS) const {
+    OS << status << " (" << base << " - "
+       << (base ? base->getName() : "nullptr") << "): ";
   }
 
 private:
@@ -547,56 +611,48 @@ private:
   Value *base; // non null only if status == base
 };
 
-typedef DenseMap<Value *, PhiState> ConflictStateMapTy;
-// Values of type PhiState form a lattice, and this is a helper
+inline raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) {
+  State.print(OS);
+  return OS;
+}
+
+
+typedef DenseMap<Value *, BDVState> ConflictStateMapTy;
+// Values of type BDVState form a lattice, and this is a helper
 // class that implementes the meet operation.  The meat of the meet
-// operation is implemented in MeetPhiStates::pureMeet
-class MeetPhiStates {
+// operation is implemented in MeetBDVStates::pureMeet
+class MeetBDVStates {
 public:
-  // phiStates is a mapping from PHINodes and SelectInst's to PhiStates.
-  explicit MeetPhiStates(const ConflictStateMapTy &phiStates)
-      : phiStates(phiStates) {}
-
-  // Destructively meet the current result with the base V.  V can
-  // either be a merge instruction (SelectInst / PHINode), in which
-  // case its status is looked up in the phiStates map; or a regular
-  // SSA value, in which case it is assumed to be a base.
-  void meetWith(Value *V) {
-    PhiState otherState = getStateForBDV(V);
-    assert((MeetPhiStates::pureMeet(otherState, currentResult) ==
-            MeetPhiStates::pureMeet(currentResult, otherState)) &&
-           "math is wrong: meet does not commute!");
-    currentResult = MeetPhiStates::pureMeet(otherState, currentResult);
+  /// Initializes the currentResult to the TOP state so that if can be met with
+  /// any other state to produce that state.
+  MeetBDVStates() {}
+
+  // Destructively meet the current result with the given BDVState
+  void meetWith(BDVState otherState) {
+    currentResult = meet(otherState, currentResult);
   }
 
-  PhiState getResult() const { return currentResult; }
+  BDVState getResult() const { return currentResult; }
 
 private:
-  const ConflictStateMapTy &phiStates;
-  PhiState currentResult;
-
-  /// Return a phi state for a base defining value.  We'll generate a new
-  /// base state for known bases and expect to find a cached state otherwise
-  PhiState getStateForBDV(Value *baseValue) {
-    if (isKnownBaseResult(baseValue)) {
-      return PhiState(baseValue);
-    } else {
-      return lookupFromMap(baseValue);
-    }
-  }
+  BDVState currentResult;
 
-  PhiState lookupFromMap(Value *V) {
-    auto I = phiStates.find(V);
-    assert(I != phiStates.end() && "lookup failed!");
-    return I->second;
+  /// Perform a meet operation on two elements of the BDVState lattice.
+  static BDVState meet(BDVState LHS, BDVState RHS) {
+    assert((pureMeet(LHS, RHS) == pureMeet(RHS, LHS)) &&
+           "math is wrong: meet does not commute!");
+    BDVState Result = pureMeet(LHS, RHS);
+    DEBUG(dbgs() << "meet of " << LHS << " with " << RHS
+                 << " produced " << Result << "\n");
+    return Result;
   }
 
-  static PhiState pureMeet(const PhiState &stateA, const PhiState &stateB) {
+  static BDVState pureMeet(const BDVState &stateA, const BDVState &stateB) {
     switch (stateA.getStatus()) {
-    case PhiState::Unknown:
+    case BDVState::Unknown:
       return stateB;
 
-    case PhiState::Base:
+    case BDVState::Base:
       assert(stateA.getBase() && "can't be null");
       if (stateB.isUnknown())
         return stateA;
@@ -606,12 +662,12 @@ private:
           assert(stateA == stateB && "equality broken!");
           return stateA;
         }
-        return PhiState(PhiState::Conflict);
+        return BDVState(BDVState::Conflict);
       }
       assert(stateB.isConflict() && "only three states!");
-      return PhiState(PhiState::Conflict);
+      return BDVState(BDVState::Conflict);
 
-    case PhiState::Conflict:
+    case BDVState::Conflict:
       return stateA;
     }
     llvm_unreachable("only three states!");
@@ -648,65 +704,82 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
   //
   // Note: A simpler form of this would be to add the conflict form of all
   // PHIs without running the optimistic algorithm.  This would be
-  // analougous to pessimistic data flow and would likely lead to an
+  // analogous to pessimistic data flow and would likely lead to an
   // overall worse solution.
 
+#ifndef NDEBUG
+  auto isExpectedBDVType = [](Value *BDV) {
+    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
+  };
+#endif
+
+  // Once populated, will contain a mapping from each potentially non-base BDV
+  // to a lattice value (described above) which corresponds to that BDV.
   ConflictStateMapTy states;
-  states[def] = PhiState();
   // Recursively fill in all phis & selects reachable from the initial one
   // for which we don't already know a definite base value for
-  // TODO: This should be rewritten with a worklist
-  bool done = false;
-  while (!done) {
-    done = true;
-    // Since we're adding elements to 'states' as we run, we can't keep
-    // iterators into the set.
-    SmallVector<Value *, 16> Keys;
-    Keys.reserve(states.size());
-    for (auto Pair : states) {
-      Value *V = Pair.first;
-      Keys.push_back(V);
-    }
-    for (Value *v : Keys) {
-      assert(!isKnownBaseResult(v) && "why did it get added?");
-      if (PHINode *phi = dyn_cast<PHINode>(v)) {
-        assert(phi->getNumIncomingValues() > 0 &&
-               "zero input phis are illegal");
-        for (Value *InVal : phi->incoming_values()) {
-          Value *local = findBaseOrBDV(InVal, cache);
-          if (!isKnownBaseResult(local) && states.find(local) == states.end()) {
-            states[local] = PhiState();
-            done = false;
-          }
-        }
-      } else if (SelectInst *sel = dyn_cast<SelectInst>(v)) {
-        Value *local = findBaseOrBDV(sel->getTrueValue(), cache);
-        if (!isKnownBaseResult(local) && states.find(local) == states.end()) {
-          states[local] = PhiState();
-          done = false;
-        }
-        local = findBaseOrBDV(sel->getFalseValue(), cache);
-        if (!isKnownBaseResult(local) && states.find(local) == states.end()) {
-          states[local] = PhiState();
-          done = false;
-        }
+  /* scope */ {
+    DenseSet<Value *> Visited;
+    SmallVector<Value*, 16> Worklist;
+    Worklist.push_back(def);
+    Visited.insert(def);
+    while (!Worklist.empty()) {
+      Value *Current = Worklist.pop_back_val();
+      assert(!isKnownBaseResult(Current) && "why did it get added?");
+
+      auto visitIncomingValue = [&](Value *InVal) {
+        Value *Base = findBaseOrBDV(InVal, cache);
+        if (isKnownBaseResult(Base))
+          // Known bases won't need new instructions introduced and can be
+          // ignored safely
+          return;
+        assert(isExpectedBDVType(Base) && "the only non-base values "
+               "we see should be base defining values");
+        if (Visited.insert(Base).second)
+          Worklist.push_back(Base);
+      };
+      if (PHINode *Phi = dyn_cast<PHINode>(Current)) {
+        for (Value *InVal : Phi->incoming_values())
+          visitIncomingValue(InVal);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(Current)) {
+        visitIncomingValue(Sel->getTrueValue());
+        visitIncomingValue(Sel->getFalseValue());
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
+        visitIncomingValue(EE->getVectorOperand());
+      } else {
+        // There are two classes of instructions we know we don't handle.
+        assert(isa<ShuffleVectorInst>(Current) ||
+               isa<InsertElementInst>(Current));
+        llvm_unreachable("unimplemented instruction case");
       }
     }
+    // The frontier of visited instructions are the ones we might need to
+    // duplicate, so fill in the starting state for the optimistic algorithm
+    // that follows.
+    for (Value *BDV : Visited) {
+      states[BDV] = BDVState();
+    }
   }
 
   if (TraceLSP) {
     errs() << "States after initialization:\n";
-    for (auto Pair : states) {
-      Instruction *v = cast<Instruction>(Pair.first);
-      PhiState state = Pair.second;
-      state.dump();
-      v->dump();
-    }
+    for (auto Pair : states)
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // TODO: come back and revisit the state transitions around inputs which
   // have reached conflict state.  The current version seems too conservative.
 
+  // Return a phi state for a base defining value.  We'll generate a new
+  // base state for known bases and expect to find a cached state otherwise.
+  auto getStateForBDV = [&](Value *baseValue) {
+    if (isKnownBaseResult(baseValue))
+      return BDVState(baseValue);
+    auto I = states.find(baseValue);
+    assert(I != states.end() && "lookup failed!");
+    return I->second;
+  };
+
   bool progress = true;
   while (progress) {
 #ifndef NDEBUG
@@ -715,18 +788,33 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
     progress = false;
     // We're only changing keys in this loop, thus safe to keep iterators
     for (auto Pair : states) {
-      MeetPhiStates calculateMeet(states);
       Value *v = Pair.first;
       assert(!isKnownBaseResult(v) && "why did it get added?");
+
+      // Given an input value for the current instruction, return a BDVState
+      // instance which represents the BDV of that value.
+      auto getStateForInput = [&](Value *V) mutable {
+        Value *BDV = findBaseOrBDV(V, cache);
+        return getStateForBDV(BDV);
+      };
+
+      MeetBDVStates calculateMeet;
       if (SelectInst *select = dyn_cast<SelectInst>(v)) {
-        calculateMeet.meetWith(findBaseOrBDV(select->getTrueValue(), cache));
-        calculateMeet.meetWith(findBaseOrBDV(select->getFalseValue(), cache));
-      } else
-        for (Value *Val : cast<PHINode>(v)->incoming_values())
-          calculateMeet.meetWith(findBaseOrBDV(Val, cache));
-
-      PhiState oldState = states[v];
-      PhiState newState = calculateMeet.getResult();
+        calculateMeet.meetWith(getStateForInput(select->getTrueValue()));
+        calculateMeet.meetWith(getStateForInput(select->getFalseValue()));
+      } else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
+        for (Value *Val : Phi->incoming_values())
+          calculateMeet.meetWith(getStateForInput(Val));
+      } else {
+        // The 'meet' for an extractelement is slightly trivial, but it's still
+        // useful in that it drives us to conflict if our input is.
+        auto *EE = cast<ExtractElementInst>(v);
+        calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
+      }
+
+
+      BDVState oldState = states[v];
+      BDVState newState = calculateMeet.getResult();
       if (oldState != newState) {
         progress = true;
         states[v] = newState;
@@ -739,12 +827,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
 
   if (TraceLSP) {
     errs() << "States after meet iteration:\n";
-    for (auto Pair : states) {
-      Instruction *v = cast<Instruction>(Pair.first);
-      PhiState state = Pair.second;
-      state.dump();
-      v->dump();
-    }
+    for (auto Pair : states)
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // Insert Phis for all conflicts
@@ -760,51 +844,67 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
   std::sort(Keys.begin(), Keys.end(), order_by_name);
   // TODO: adjust naming patterns to avoid this order of iteration dependency
   for (Value *V : Keys) {
-    Instruction *v = cast<Instruction>(V);
-    PhiState state = states[V];
-    assert(!isKnownBaseResult(v) && "why did it get added?");
-    assert(!state.isUnknown() && "Optimistic algorithm didn't complete!");
-    if (!state.isConflict())
+    Instruction *I = cast<Instruction>(V);
+    BDVState State = states[I];
+    assert(!isKnownBaseResult(I) && "why did it get added?");
+    assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
+
+    // extractelement instructions are a bit special in that we may need to
+    // insert an extract even when we know an exact base for the instruction.
+    // The problem is that we need to convert from a vector base to a scalar
+    // base for the particular indice we're interested in.
+    if (State.isBase() && isa<ExtractElementInst>(I) &&
+        isa<VectorType>(State.getBase()->getType())) {
+      auto *EE = cast<ExtractElementInst>(I);
+      // TODO: In many cases, the new instruction is just EE itself.  We should
+      // exploit this, but can't do it here since it would break the invariant
+      // about the BDV not being known to be a base.
+      auto *BaseInst = ExtractElementInst::Create(State.getBase(),
+                                                  EE->getIndexOperand(),
+                                                  "base_ee", EE);
+      BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+      states[I] = BDVState(BDVState::Base, BaseInst);
+    }
+    
+    if (!State.isConflict())
       continue;
 
-    if (isa<PHINode>(v)) {
-      int num_preds =
-          std::distance(pred_begin(v->getParent()), pred_end(v->getParent()));
-      assert(num_preds > 0 && "how did we reach here");
-      PHINode *phi = PHINode::Create(v->getType(), num_preds, "base_phi", v);
-      // Add metadata marking this as a base value
-      auto *const_1 = ConstantInt::get(
-          Type::getInt32Ty(
-              v->getParent()->getParent()->getParent()->getContext()),
-          1);
-      auto MDConst = ConstantAsMetadata::get(const_1);
-      MDNode *md = MDNode::get(
-          v->getParent()->getParent()->getParent()->getContext(), MDConst);
-      phi->setMetadata("is_base_value", md);
-      states[v] = PhiState(PhiState::Conflict, phi);
-    } else {
-      SelectInst *sel = cast<SelectInst>(v);
-      // The undef will be replaced later
-      UndefValue *undef = UndefValue::get(sel->getType());
-      SelectInst *basesel = SelectInst::Create(sel->getCondition(), undef,
-                                               undef, "base_select", sel);
-      // Add metadata marking this as a base value
-      auto *const_1 = ConstantInt::get(
-          Type::getInt32Ty(
-              v->getParent()->getParent()->getParent()->getContext()),
-          1);
-      auto MDConst = ConstantAsMetadata::get(const_1);
-      MDNode *md = MDNode::get(
-          v->getParent()->getParent()->getParent()->getContext(), MDConst);
-      basesel->setMetadata("is_base_value", md);
-      states[v] = PhiState(PhiState::Conflict, basesel);
-    }
+    /// Create and insert a new instruction which will represent the base of
+    /// the given instruction 'I'.
+    auto MakeBaseInstPlaceholder = [](Instruction *I) -> Instruction* {
+      if (isa<PHINode>(I)) {
+        BasicBlock *BB = I->getParent();
+        int NumPreds = std::distance(pred_begin(BB), pred_end(BB));
+        assert(NumPreds > 0 && "how did we reach here");
+        std::string Name = I->hasName() ?
+           (I->getName() + ".base").str() : "base_phi";
+        return PHINode::Create(I->getType(), NumPreds, Name, I);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(I)) {
+        // The undef will be replaced later
+        UndefValue *Undef = UndefValue::get(Sel->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_select";
+        return SelectInst::Create(Sel->getCondition(), Undef,
+                                  Undef, Name, Sel);
+      } else {
+        auto *EE = cast<ExtractElementInst>(I);
+        UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_ee";
+        return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
+                                          EE);
+      }
+    };
+    Instruction *BaseInst = MakeBaseInstPlaceholder(I);
+    // Add metadata marking this as a base value
+    BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+    states[I] = BDVState(BDVState::Conflict, BaseInst);
   }
 
   // Fixup all the inputs of the new PHIs
   for (auto Pair : states) {
     Instruction *v = cast<Instruction>(Pair.first);
-    PhiState state = Pair.second;
+    BDVState state = Pair.second;
 
     assert(!isKnownBaseResult(v) && "why did it get added?");
     assert(!state.isUnknown() && "Optimistic algorithm didn't complete!");
@@ -837,10 +937,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
             // Either conflict or base.
             assert(states.count(base));
             base = states[base].getBase();
-            assert(base != nullptr && "unknown PhiState!");
+            assert(base != nullptr && "unknown BDVState!");
           }
 
-          // In essense this assert states: the only way two
+          // In essence this assert states: the only way two
           // values incoming from the same basic block may be
           // different is by being different bitcasts of the same
           // value.  A cleanup that remains TODO is changing
@@ -860,7 +960,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
           // Either conflict or base.
           assert(states.count(base));
           base = states[base].getBase();
-          assert(base != nullptr && "unknown PhiState!");
+          assert(base != nullptr && "unknown BDVState!");
         }
         assert(base && "can't be null");
         // Must use original input BB since base may not be Instruction
@@ -872,8 +972,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         basephi->addIncoming(base, InBB);
       }
       assert(basephi->getNumIncomingValues() == NumPHIValues);
-    } else {
-      SelectInst *basesel = cast<SelectInst>(state.getBase());
+    } else if (SelectInst *basesel = dyn_cast<SelectInst>(state.getBase())) {
       SelectInst *sel = cast<SelectInst>(v);
       // Operand 1 & 2 are true, false path respectively. TODO: refactor to
       // something more safe and less hacky.
@@ -886,7 +985,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
           // Either conflict or base.
           assert(states.count(base));
           base = states[base].getBase();
-          assert(base != nullptr && "unknown PhiState!");
+          assert(base != nullptr && "unknown BDVState!");
         }
         assert(base && "can't be null");
         // Must use original input BB since base may not be Instruction
@@ -896,6 +995,74 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         }
         basesel->setOperand(i, base);
       }
+    } else {
+      auto *BaseEE = cast<ExtractElementInst>(state.getBase());
+      Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
+      Value *Base = findBaseOrBDV(InVal, cache);
+      if (!isKnownBaseResult(Base)) {
+        // Either conflict or base.
+        assert(states.count(Base));
+        Base = states[Base].getBase();
+        assert(Base != nullptr && "unknown BDVState!");
+      }
+      assert(Base && "can't be null");
+      BaseEE->setOperand(0, Base);
+    }
+  }
+
+  // Now that we're done with the algorithm, see if we can optimize the 
+  // results slightly by reducing the number of new instructions needed. 
+  // Arguably, this should be integrated into the algorithm above, but 
+  // doing as a post process step is easier to reason about for the moment.
+  DenseMap<Value *, Value *> ReverseMap;
+  SmallPtrSet<Instruction *, 16> NewInsts;
+  SmallSetVector<Instruction *, 16> Worklist;
+  for (auto Item : states) {
+    Value *V = Item.first;
+    Value *Base = Item.second.getBase();
+    assert(V && Base);
+    assert(!isKnownBaseResult(V) && "why did it get added?");
+    assert(isKnownBaseResult(Base) &&
+           "must be something we 'know' is a base pointer");
+    if (!Item.second.isConflict())
+      continue;
+
+    ReverseMap[Base] = V;
+    if (auto *BaseI = dyn_cast<Instruction>(Base)) {
+      NewInsts.insert(BaseI);
+      Worklist.insert(BaseI);
+    }
+  }
+  auto PushNewUsers = [&](Instruction *I) {
+    for (User *U : I->users())
+      if (auto *UI = dyn_cast<Instruction>(U))
+        if (NewInsts.count(UI))
+          Worklist.insert(UI);
+  };
+  const DataLayout &DL = cast<Instruction>(def)->getModule()->getDataLayout();
+  while (!Worklist.empty()) {
+    Instruction *BaseI = Worklist.pop_back_val();
+    Value *Bdv = ReverseMap[BaseI];
+    if (auto *BdvI = dyn_cast<Instruction>(Bdv))
+      if (BaseI->isIdenticalTo(BdvI)) {
+        DEBUG(dbgs() << "Identical Base: " << *BaseI << "\n");
+        PushNewUsers(BaseI);
+        BaseI->replaceAllUsesWith(Bdv);
+        BaseI->eraseFromParent();
+        states[Bdv] = BDVState(BDVState::Conflict, Bdv);
+        NewInsts.erase(BaseI);
+        ReverseMap.erase(BaseI);
+        continue;
+      }
+    if (Value *V = SimplifyInstruction(BaseI, DL)) {
+      DEBUG(dbgs() << "Base " << *BaseI << " simplified to " << *V << "\n");
+      PushNewUsers(BaseI);
+      BaseI->replaceAllUsesWith(V);
+      BaseI->eraseFromParent();
+      states[Bdv] = BDVState(BDVState::Conflict, V);
+      NewInsts.erase(BaseI);
+      ReverseMap.erase(BaseI);
+      continue;
     }
   }
 
@@ -906,7 +1073,6 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
     Value *v = item.first;
     Value *base = item.second.getBase();
     assert(v && base);
-    assert(!isKnownBaseResult(v) && "why did it get added?");
 
     if (TraceLSP) {
       std::string fromstr =
@@ -918,8 +1084,6 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
              << " to: " << (base->hasName() ? base->getName() : "") << "\n";
     }
 
-    assert(isKnownBaseResult(base) &&
-           "must be something we 'know' is a base pointer");
     if (cache.count(v)) {
       // Once we transition from the BDV relation being store in the cache to
       // the base relation being stored, it must be stable
@@ -968,7 +1132,7 @@ findBasePointers(const StatepointLiveSetTy &live,
 
     // If you see this trip and like to live really dangerously, the code should
     // be correct, just with idioms the verifier can't handle.  You can try
-    // disabling the verifier at your own substaintial risk.
+    // disabling the verifier at your own substantial risk.
     assert(!isa<ConstantPointerNull>(base) &&
            "the relocation code needs adjustment to handle the relocation of "
            "a null pointer constant without causing false positives in the "
@@ -1014,7 +1178,7 @@ static void recomputeLiveInValues(
     Function &F, DominatorTree &DT, Pass *P, ArrayRef<CallSite> toUpdate,
     MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) {
   // TODO-PERF: reuse the original liveness, then simply run the dataflow
-  // again.  The old values are still live and will help it stablize quickly.
+  // again.  The old values are still live and will help it stabilize quickly.
   GCPtrLivenessData RevisedLivenessData;
   computeLiveInValues(DT, F, RevisedLivenessData);
   for (size_t i = 0; i < records.size(); i++) {
@@ -1031,14 +1195,11 @@ static void recomputeLiveInValues(
 // goes through the statepoint.  We might need to split an edge to make this
 // possible.
 static BasicBlock *
-normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent, Pass *P) {
-  DominatorTree *DT = nullptr;
-  if (auto *DTP = P->getAnalysisIfAvailable<DominatorTreeWrapperPass>())
-    DT = &DTP->getDomTree();
-
+normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent,
+                            DominatorTree &DT) {
   BasicBlock *Ret = BB;
   if (!BB->getUniquePredecessor()) {
-    Ret = SplitBlockPredecessors(BB, InvokeParent, "", nullptr, DT);
+    Ret = SplitBlockPredecessors(BB, InvokeParent, "", &DT);
   }
 
   // Now that 'ret' has unique predecessor we can safely remove all phi nodes
@@ -1059,7 +1220,7 @@ static int find_index(ArrayRef<Value *> livevec, Value *val) {
   return index;
 }
 
-// Create new attribute set containing only attributes which can be transfered
+// Create new attribute set containing only attributes which can be transferred
 // from original call to the safepoint.
 static AttributeSet legalizeCallAttributes(AttributeSet AS) {
   AttributeSet ret;
@@ -1106,45 +1267,37 @@ static void CreateGCRelocates(ArrayRef<llvm::Value *> LiveVariables,
                               ArrayRef<llvm::Value *> BasePtrs,
                               Instruction *StatepointToken,
                               IRBuilder<> Builder) {
-  SmallVector<Instruction *, 64> NewDefs;
-  NewDefs.reserve(LiveVariables.size());
-
-  Module *M = StatepointToken->getParent()->getParent()->getParent();
+  if (LiveVariables.empty())
+    return;
+  
+  // All gc_relocate are set to i8 addrspace(1)* type. We originally generated
+  // unique declarations for each pointer type, but this proved problematic
+  // because the intrinsic mangling code is incomplete and fragile.  Since
+  // we're moving towards a single unified pointer type anyways, we can just
+  // cast everything to an i8* of the right address space.  A bitcast is added
+  // later to convert gc_relocate to the actual value's type. 
+  Module *M = StatepointToken->getModule();
+  auto AS = cast<PointerType>(LiveVariables[0]->getType())->getAddressSpace();
+  Type *Types[] = {Type::getInt8PtrTy(M->getContext(), AS)};
+  Value *GCRelocateDecl =
+    Intrinsic::getDeclaration(M, Intrinsic::experimental_gc_relocate, Types);
 
   for (unsigned i = 0; i < LiveVariables.size(); i++) {
-    // We generate a (potentially) unique declaration for every pointer type
-    // combination.  This results is some blow up the function declarations in
-    // the IR, but removes the need for argument bitcasts which shrinks the IR
-    // greatly and makes it much more readable.
-    SmallVector<Type *, 1> Types;                 // one per 'any' type
-    // All gc_relocate are set to i8 addrspace(1)* type. This could help avoid
-    // cases where the actual value's type mangling is not supported by llvm. A
-    // bitcast is added later to convert gc_relocate to the actual value's type.
-    Types.push_back(Type::getInt8PtrTy(M->getContext(), 1));
-    Value *GCRelocateDecl = Intrinsic::getDeclaration(
-        M, Intrinsic::experimental_gc_relocate, Types);
-
     // Generate the gc.relocate call and save the result
     Value *BaseIdx =
-        ConstantInt::get(Type::getInt32Ty(M->getContext()),
-                         LiveStart + find_index(LiveVariables, BasePtrs[i]));
-    Value *LiveIdx = ConstantInt::get(
-        Type::getInt32Ty(M->getContext()),
-        LiveStart + find_index(LiveVariables, LiveVariables[i]));
+      Builder.getInt32(LiveStart + find_index(LiveVariables, BasePtrs[i]));
+    Value *LiveIdx =
+      Builder.getInt32(LiveStart + find_index(LiveVariables, LiveVariables[i]));
 
     // only specify a debug name if we can give a useful one
-    Value *Reloc = Builder.CreateCall(
+    CallInst *Reloc = Builder.CreateCall(
         GCRelocateDecl, {StatepointToken, BaseIdx, LiveIdx},
         LiveVariables[i]->hasName() ? LiveVariables[i]->getName() + ".relocated"
                                     : "");
     // Trick CodeGen into thinking there are lots of free registers at this
     // fake call.
-    cast<CallInst>(Reloc)->setCallingConv(CallingConv::Cold);
-
-    NewDefs.push_back(cast<Instruction>(Reloc));
+    Reloc->setCallingConv(CallingConv::Cold);
   }
-  assert(NewDefs.size() == LiveVariables.size() &&
-         "missing or extra redefinition at safepoint");
 }
 
 static void
@@ -1199,7 +1352,7 @@ makeStatepointExplicitImpl(const CallSite &CS, /* to replace */
     // Currently we will fail on parameter attributes and on certain
     // function attributes.
     AttributeSet new_attrs = legalizeCallAttributes(toReplace->getAttributes());
-    // In case if we can handle this set of sttributes - set up function attrs
+    // In case if we can handle this set of attributes - set up function attrs
     // directly on statepoint and return attrs later for gc_result intrinsic.
     call->setAttributes(new_attrs.getFnAttributes());
     return_attributes = new_attrs.getRetAttributes();
@@ -1223,13 +1376,13 @@ makeStatepointExplicitImpl(const CallSite &CS, /* to replace */
     // original block.
     InvokeInst *invoke = InvokeInst::Create(
         gc_statepoint_decl, toReplace->getNormalDest(),
-        toReplace->getUnwindDest(), args, "", toReplace->getParent());
+        toReplace->getUnwindDest(), args, "statepoint_token", toReplace->getParent());
     invoke->setCallingConv(toReplace->getCallingConv());
 
     // Currently we will fail on parameter attributes and on certain
     // function attributes.
     AttributeSet new_attrs = legalizeCallAttributes(toReplace->getAttributes());
-    // In case if we can handle this set of sttributes - set up function attrs
+    // In case if we can handle this set of attributes - set up function attrs
     // directly on statepoint and return attrs later for gc_result intrinsic.
     invoke->setAttributes(new_attrs.getFnAttributes());
     return_attributes = new_attrs.getRetAttributes();
@@ -1254,10 +1407,8 @@ makeStatepointExplicitImpl(const CallSite &CS, /* to replace */
             unwindBlock->getLandingPadInst(), idx, "relocate_token"));
     result.UnwindToken = exceptional_token;
 
-    // Just throw away return value. We will use the one we got for normal
-    // block.
-    (void)CreateGCRelocates(liveVariables, live_start, basePtrs,
-                            exceptional_token, Builder);
+    CreateGCRelocates(liveVariables, live_start, basePtrs,
+                      exceptional_token, Builder);
 
     // Generate gc relocates and returns for normal block
     BasicBlock *normalDest = toReplace->getNormalDest();
@@ -1340,8 +1491,7 @@ makeStatepointExplicit(DominatorTree &DT, const CallSite &CS, Pass *P,
   basevec.reserve(liveset.size());
   for (Value *L : liveset) {
     livevec.push_back(L);
-
-    assert(PointerToBase.find(L) != PointerToBase.end());
+    assert(PointerToBase.count(L));
     Value *base = PointerToBase[L];
     basevec.push_back(base);
   }
@@ -1510,7 +1660,7 @@ static void relocationViaAlloca(
                                   VisitedLiveValues);
 
     if (ClobberNonLive) {
-      // As a debuging aid, pretend that an unrelocated pointer becomes null at
+      // As a debugging aid, pretend that an unrelocated pointer becomes null at
       // the gc.statepoint.  This will turn some subtle GC problems into
       // slightly easier to debug SEGVs.  Note that on large IR files with
       // lots of gc.statepoints this is extremely costly both memory and time
@@ -1633,17 +1783,10 @@ static void relocationViaAlloca(
 /// vector.  Doing so has the effect of changing the output of a couple of
 /// tests in ways which make them less useful in testing fused safepoints.
 template <typename T> static void unique_unsorted(SmallVectorImpl<T> &Vec) {
-  DenseSet<T> Seen;
-  SmallVector<T, 128> TempVec;
-  TempVec.reserve(Vec.size());
-  for (auto Element : Vec)
-    TempVec.push_back(Element);
-  Vec.clear();
-  for (auto V : TempVec) {
-    if (Seen.insert(V).second) {
-      Vec.push_back(V);
-    }
-  }
+  SmallSet<T, 8> Seen;
+  Vec.erase(std::remove_if(Vec.begin(), Vec.end(), [&](const T &V) {
+              return !Seen.insert(V).second;
+            }), Vec.end());
 }
 
 /// Insert holders so that each Value is obviously live through the entire
@@ -1688,12 +1831,14 @@ static void findLiveReferences(
 
 /// Remove any vector of pointers from the liveset by scalarizing them over the
 /// statepoint instruction.  Adds the scalarized pieces to the liveset.  It
-/// would be preferrable to include the vector in the statepoint itself, but
+/// would be preferable to include the vector in the statepoint itself, but
 /// the lowering code currently does not handle that.  Extending it would be
 /// slightly non-trivial since it requires a format change.  Given how rare
-/// such cases are (for the moment?) scalarizing is an acceptable comprimise.
+/// such cases are (for the moment?) scalarizing is an acceptable compromise.
 static void splitVectorValues(Instruction *StatepointInst,
-                              StatepointLiveSetTy &LiveSet, DominatorTree &DT) {
+                              StatepointLiveSetTy &LiveSet,
+                              DenseMap<Value *, Value *>& PointerToBase,
+                              DominatorTree &DT) {
   SmallVector<Value *, 16> ToSplit;
   for (Value *V : LiveSet)
     if (isa<VectorType>(V->getType()))
@@ -1702,14 +1847,14 @@ static void splitVectorValues(Instruction *StatepointInst,
   if (ToSplit.empty())
     return;
 
+  DenseMap<Value *, SmallVector<Value *, 16>> ElementMapping;
+
   Function &F = *(StatepointInst->getParent()->getParent());
 
   DenseMap<Value *, AllocaInst *> AllocaMap;
   // First is normal return, second is exceptional return (invoke only)
   DenseMap<Value *, std::pair<Value *, Value *>> Replacements;
   for (Value *V : ToSplit) {
-    LiveSet.erase(V);
-
     AllocaInst *Alloca =
         new AllocaInst(V->getType(), "", F.getEntryBlock().getFirstNonPHI());
     AllocaMap[V] = Alloca;
@@ -1719,7 +1864,7 @@ static void splitVectorValues(Instruction *StatepointInst,
     SmallVector<Value *, 16> Elements;
     for (unsigned i = 0; i < VT->getNumElements(); i++)
       Elements.push_back(Builder.CreateExtractElement(V, Builder.getInt32(i)));
-    LiveSet.insert(Elements.begin(), Elements.end());
+    ElementMapping[V] = Elements;
 
     auto InsertVectorReform = [&](Instruction *IP) {
       Builder.SetInsertPoint(IP);
@@ -1752,6 +1897,7 @@ static void splitVectorValues(Instruction *StatepointInst,
       Replacements[V].second = InsertVectorReform(IP);
     }
   }
+
   for (Value *V : ToSplit) {
     AllocaInst *Alloca = AllocaMap[V];
 
@@ -1795,12 +1941,31 @@ static void splitVectorValues(Instruction *StatepointInst,
   for (Value *V : ToSplit)
     Allocas.push_back(AllocaMap[V]);
   PromoteMemToReg(Allocas, DT);
+
+  // Update our tracking of live pointers and base mappings to account for the
+  // changes we just made.
+  for (Value *V : ToSplit) {
+    auto &Elements = ElementMapping[V];
+
+    LiveSet.erase(V);
+    LiveSet.insert(Elements.begin(), Elements.end());
+    // We need to update the base mapping as well.
+    assert(PointerToBase.count(V));
+    Value *OldBase = PointerToBase[V];
+    auto &BaseElements = ElementMapping[OldBase];
+    PointerToBase.erase(V);
+    assert(Elements.size() == BaseElements.size());
+    for (unsigned i = 0; i < Elements.size(); i++) {
+      Value *Elem = Elements[i];
+      PointerToBase[Elem] = BaseElements[i];
+    }
+  }
 }
 
 // Helper function for the "rematerializeLiveValues". It walks use chain
 // starting from the "CurrentValue" until it meets "BaseValue". Only "simple"
 // values are visited (currently it is GEP's and casts). Returns true if it
-// sucessfully reached "BaseValue" and false otherwise.
+// successfully reached "BaseValue" and false otherwise.
 // Fills "ChainToBase" array with all visited values. "BaseValue" is not
 // recorded.
 static bool findRematerializableChainToBasePointer(
@@ -1877,8 +2042,8 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
 static void rematerializeLiveValues(CallSite CS,
                                     PartiallyConstructedSafepointRecord &Info,
                                     TargetTransformInfo &TTI) {
-  const int ChainLengthThreshold = 10;
-  
+  const unsigned int ChainLengthThreshold = 10;
+
   // Record values we are going to delete from this statepoint live set.
   // We can not di this in following loop due to iterator invalidation.
   SmallVector<Value *, 32> LiveValuesToBeDeleted;
@@ -1886,7 +2051,7 @@ static void rematerializeLiveValues(CallSite CS,
   for (Value *LiveValue: Info.liveset) {
     // For each live pointer find it's defining chain
     SmallVector<Instruction *, 3> ChainToBase;
-    assert(Info.PointerToBase.find(LiveValue) != Info.PointerToBase.end());
+    assert(Info.PointerToBase.count(LiveValue));
     bool FoundChain =
       findRematerializableChainToBasePointer(ChainToBase,
                                              LiveValue,
@@ -1944,12 +2109,12 @@ static void rematerializeLiveValues(CallSite CS,
           assert(LastValue);
           ClonedValue->replaceUsesOfWith(LastValue, LastClonedValue);
 #ifndef NDEBUG
-          // Assert that cloned instruction does not use any instructions
-          // other than LastClonedValue
-          for (auto OpValue: ClonedValue->operand_values()) {
-            if (isa<Instruction>(OpValue))
-              assert(OpValue == LastClonedValue &&
-                     "unexpected use found in rematerialized value");
+          // Assert that cloned instruction does not use any instructions from
+          // this chain other than LastClonedValue
+          for (auto OpValue : ClonedValue->operand_values()) {
+            assert(std::find(ChainToBase.begin(), ChainToBase.end(), OpValue) ==
+                       ChainToBase.end() &&
+                   "incorrect use in rematerialization chain");
           }
 #endif
         }
@@ -2016,9 +2181,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
       continue;
     InvokeInst *invoke = cast<InvokeInst>(CS.getInstruction());
     normalizeForInvokeSafepoint(invoke->getNormalDest(), invoke->getParent(),
-                                P);
+                                DT);
     normalizeForInvokeSafepoint(invoke->getUnwindDest(), invoke->getParent(),
-                                P);
+                                DT);
   }
 
   // A list of dummy calls added to the IR to keep various values obviously
@@ -2052,21 +2217,10 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   }
   assert(records.size() == toUpdate.size());
 
-  // A) Identify all gc pointers which are staticly live at the given call
+  // A) Identify all gc pointers which are statically live at the given call
   // site.
   findLiveReferences(F, DT, P, toUpdate, records);
 
-  // Do a limited scalarization of any live at safepoint vector values which
-  // contain pointers.  This enables this pass to run after vectorization at
-  // the cost of some possible performance loss.  TODO: it would be nice to
-  // natively support vectors all the way through the backend so we don't need
-  // to scalarize here.
-  for (size_t i = 0; i < records.size(); i++) {
-    struct PartiallyConstructedSafepointRecord &info = records[i];
-    Instruction *statepoint = toUpdate[i].getInstruction();
-    splitVectorValues(cast<Instruction>(statepoint), info.liveset, DT);
-  }
-
   // B) Find the base pointers for each live pointer
   /* scope for caching */ {
     // Cache the 'defining value' relation used in the computation and
@@ -2127,13 +2281,25 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   }
   holders.clear();
 
+  // Do a limited scalarization of any live at safepoint vector values which
+  // contain pointers.  This enables this pass to run after vectorization at
+  // the cost of some possible performance loss.  TODO: it would be nice to
+  // natively support vectors all the way through the backend so we don't need
+  // to scalarize here.
+  for (size_t i = 0; i < records.size(); i++) {
+    struct PartiallyConstructedSafepointRecord &info = records[i];
+    Instruction *statepoint = toUpdate[i].getInstruction();
+    splitVectorValues(cast<Instruction>(statepoint), info.liveset,
+                      info.PointerToBase, DT);
+  }
+
   // In order to reduce live set of statepoint we might choose to rematerialize
-  // some values instead of relocating them. This is purelly an optimization and
+  // some values instead of relocating them. This is purely an optimization and
   // does not influence correctness.
   TargetTransformInfo &TTI =
     P->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
-  for (size_t i = 0; i < records.size(); i++) { 
+  for (size_t i = 0; i < records.size(); i++) {
     struct PartiallyConstructedSafepointRecord &info = records[i];
     CallSite &CS = toUpdate[i];
 
@@ -2197,21 +2363,99 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   return !records.empty();
 }
 
+// Handles both return values and arguments for Functions and CallSites.
+template <typename AttrHolder>
+static void RemoveDerefAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH,
+                                   unsigned Index) {
+  AttrBuilder R;
+  if (AH.getDereferenceableBytes(Index))
+    R.addAttribute(Attribute::get(Ctx, Attribute::Dereferenceable,
+                                  AH.getDereferenceableBytes(Index)));
+  if (AH.getDereferenceableOrNullBytes(Index))
+    R.addAttribute(Attribute::get(Ctx, Attribute::DereferenceableOrNull,
+                                  AH.getDereferenceableOrNullBytes(Index)));
+
+  if (!R.empty())
+    AH.setAttributes(AH.getAttributes().removeAttributes(
+        Ctx, Index, AttributeSet::get(Ctx, Index, R)));
+}
+
+void
+RewriteStatepointsForGC::stripDereferenceabilityInfoFromPrototype(Function &F) {
+  LLVMContext &Ctx = F.getContext();
+
+  for (Argument &A : F.args())
+    if (isa<PointerType>(A.getType()))
+      RemoveDerefAttrAtIndex(Ctx, F, A.getArgNo() + 1);
+
+  if (isa<PointerType>(F.getReturnType()))
+    RemoveDerefAttrAtIndex(Ctx, F, AttributeSet::ReturnIndex);
+}
+
+void RewriteStatepointsForGC::stripDereferenceabilityInfoFromBody(Function &F) {
+  if (F.empty())
+    return;
+
+  LLVMContext &Ctx = F.getContext();
+  MDBuilder Builder(Ctx);
+
+  for (Instruction &I : instructions(F)) {
+    if (const MDNode *MD = I.getMetadata(LLVMContext::MD_tbaa)) {
+      assert(MD->getNumOperands() < 5 && "unrecognized metadata shape!");
+      bool IsImmutableTBAA =
+          MD->getNumOperands() == 4 &&
+          mdconst::extract<ConstantInt>(MD->getOperand(3))->getValue() == 1;
+
+      if (!IsImmutableTBAA)
+        continue; // no work to do, MD_tbaa is already marked mutable
+
+      MDNode *Base = cast<MDNode>(MD->getOperand(0));
+      MDNode *Access = cast<MDNode>(MD->getOperand(1));
+      uint64_t Offset =
+          mdconst::extract<ConstantInt>(MD->getOperand(2))->getZExtValue();
+
+      MDNode *MutableTBAA =
+          Builder.createTBAAStructTagNode(Base, Access, Offset);
+      I.setMetadata(LLVMContext::MD_tbaa, MutableTBAA);
+    }
+
+    if (CallSite CS = CallSite(&I)) {
+      for (int i = 0, e = CS.arg_size(); i != e; i++)
+        if (isa<PointerType>(CS.getArgument(i)->getType()))
+          RemoveDerefAttrAtIndex(Ctx, CS, i + 1);
+      if (isa<PointerType>(CS.getType()))
+        RemoveDerefAttrAtIndex(Ctx, CS, AttributeSet::ReturnIndex);
+    }
+  }
+}
+
 /// Returns true if this function should be rewritten by this pass.  The main
 /// point of this function is as an extension point for custom logic.
 static bool shouldRewriteStatepointsIn(Function &F) {
   // TODO: This should check the GCStrategy
   if (F.hasGC()) {
-    const char *FunctionGCName = F.getGC();\r
-    const StringRef StatepointExampleName("statepoint-example");\r
-    const StringRef CoreCLRName("coreclr");\r
-    return (StatepointExampleName == FunctionGCName) ||\r
-      (CoreCLRName == FunctionGCName);
-  }
-  else
+    const char *FunctionGCName = F.getGC();
+    const StringRef StatepointExampleName("statepoint-example");
+    const StringRef CoreCLRName("coreclr");
+    return (StatepointExampleName == FunctionGCName) ||
+           (CoreCLRName == FunctionGCName);
+  } else
     return false;
 }
 
+void RewriteStatepointsForGC::stripDereferenceabilityInfo(Module &M) {
+#ifndef NDEBUG
+  assert(std::any_of(M.begin(), M.end(), shouldRewriteStatepointsIn) &&
+         "precondition!");
+#endif
+
+  for (Function &F : M)
+    stripDereferenceabilityInfoFromPrototype(F);
+
+  for (Function &F : M)
+    stripDereferenceabilityInfoFromBody(F);
+}
+
 bool RewriteStatepointsForGC::runOnFunction(Function &F) {
   // Nothing to do for declarations.
   if (F.isDeclaration() || F.empty())
@@ -2222,14 +2466,14 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) {
   if (!shouldRewriteStatepointsIn(F))
     return false;
 
-  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
 
   // Gather all the statepoints which need rewritten.  Be careful to only
   // consider those in reachable code since we need to ask dominance queries
   // when rewriting.  We'll delete the unreachable ones in a moment.
   SmallVector<CallSite, 64> ParsePointNeeded;
   bool HasUnreachableStatepoint = false;
-  for (Instruction &I : inst_range(F)) {
+  for (Instruction &I : instructions(F)) {
     // TODO: only the ones with the flag set!
     if (isStatepoint(I)) {
       if (DT.isReachableFromEntry(I.getParent()))
@@ -2262,6 +2506,37 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) {
       FoldSingleEntryPHINodes(&BB);
     }
 
+  // Before we start introducing relocations, we want to tweak the IR a bit to
+  // avoid unfortunate code generation effects.  The main example is that we 
+  // want to try to make sure the comparison feeding a branch is after any
+  // safepoints.  Otherwise, we end up with a comparison of pre-relocation
+  // values feeding a branch after relocation.  This is semantically correct,
+  // but results in extra register pressure since both the pre-relocation and
+  // post-relocation copies must be available in registers.  For code without
+  // relocations this is handled elsewhere, but teaching the scheduler to
+  // reverse the transform we're about to do would be slightly complex.
+  // Note: This may extend the live range of the inputs to the icmp and thus
+  // increase the liveset of any statepoint we move over.  This is profitable
+  // as long as all statepoints are in rare blocks.  If we had in-register
+  // lowering for live values this would be a much safer transform.
+  auto getConditionInst = [](TerminatorInst *TI) -> Instruction* {
+    if (auto *BI = dyn_cast<BranchInst>(TI))
+      if (BI->isConditional())
+        return dyn_cast<Instruction>(BI->getCondition());
+    // TODO: Extend this to handle switches
+    return nullptr;
+  };
+  for (BasicBlock &BB : F) {
+    TerminatorInst *TI = BB.getTerminator();
+    if (auto *Cond = getConditionInst(TI))
+      // TODO: Handle more than just ICmps here.  We should be able to move
+      // most instructions without side effects or memory access.  
+      if (isa<ICmpInst>(Cond) && Cond->hasOneUse()) {
+        MadeChange = true;
+        Cond->moveBefore(TI);
+      }
+  }
+
   MadeChange |= insertParsePoints(F, DT, this, ParsePointNeeded);
   return MadeChange;
 }
@@ -2295,7 +2570,7 @@ static void computeLiveInValues(BasicBlock::reverse_iterator rbegin,
              "support for FCA unimplemented");
       if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) {
         // The choice to exclude all things constant here is slightly subtle.
-        // There are two idependent reasons:
+        // There are two independent reasons:
         // - We assume that things which are constant (from LLVM's definition)
         // do not move at runtime.  For example, the address of a global
         // variable is fixed, even though it's contents may not be.
@@ -2433,7 +2708,7 @@ static void computeLiveInValues(DominatorTree &DT, Function &F,
   } // while( !worklist.empty() )
 
 #ifndef NDEBUG
-  // Sanity check our ouput against SSA properties.  This helps catch any
+  // Sanity check our output against SSA properties.  This helps catch any
   // missing kills during the above iteration.
   for (BasicBlock &BB : F) {
     checkBasicSSA(DT, Data, BB);