[SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning
authorSilviu Baranga <silviu.baranga@arm.com>
Mon, 2 Nov 2015 14:41:02 +0000 (14:41 +0000)
committerSilviu Baranga <silviu.baranga@arm.com>
Mon, 2 Nov 2015 14:41:02 +0000 (14:41 +0000)
Summary:
SCEV Predicates represent conditions that typically cannot be derived from
static analysis, but can be used to reduce SCEV expressions to forms which are
usable for different optimizers.

ScalarEvolution now has the rewriteUsingPredicate method which can simplify a
SCEV expression using a SCEVPredicateSet. The normal workflow of a pass using
SCEVPredicates would be to hold a SCEVPredicateSet and every time assumptions
need to be made a new SCEV Predicate would be created and added to the set.
Each time after calling getSCEV, the user will call the rewriteUsingPredicate
method.

We add two types of predicates
SCEVPredicateSet - implements a set of predicates
SCEVEqualPredicate - tests for equality between two SCEV expressions

We use the SCEVEqualPredicate to re-implement stride versioning. Every time we
version a stride, we will add a SCEVEqualPredicate to the context.
Instead of adding specific stride checks, LoopVectorize now adds a more
generic SCEV check.

We only need to add support for this in the LoopVectorizer since this is the
only pass that will do stride versioning.

Reviewers: mzolotukhin, anemet, hfinkel, sanjoy

Subscribers: sanjoy, hfinkel, rengolin, jmolloy, llvm-commits

Differential Revision: http://reviews.llvm.org/D13595

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251800 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/LoopAccessAnalysis.h
include/llvm/Analysis/ScalarEvolution.h
include/llvm/Analysis/ScalarEvolutionExpander.h
lib/Analysis/LoopAccessAnalysis.cpp
lib/Analysis/ScalarEvolution.cpp
lib/Analysis/ScalarEvolutionExpander.cpp
lib/Transforms/Vectorize/LoopVectorize.cpp

index 89ed952ee81de0af5f8d8d5624299a8fa5e33650..97a72b3d1cd15aa3678efbf2199ab61035c5b1c5 100644 (file)
@@ -32,6 +32,7 @@ class DataLayout;
 class ScalarEvolution;
 class Loop;
 class SCEV;
+class SCEVUnionPredicate;
 
 /// Optimization analysis message produced during vectorization. Messages inform
 /// the user why vectorization did not occur.
@@ -176,10 +177,11 @@ public:
                const SmallVectorImpl<Instruction *> &Instrs) const;
   };
 
-  MemoryDepChecker(ScalarEvolution *Se, const Loop *L)
+  MemoryDepChecker(ScalarEvolution *Se, const Loop *L,
+                   SCEVUnionPredicate &Preds)
       : SE(Se), InnermostLoop(L), AccessIdx(0),
         ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
-        RecordInterestingDependences(true) {}
+        RecordInterestingDependences(true), Preds(Preds) {}
 
   /// \brief Register the location (instructions are given increasing numbers)
   /// of a write access.
@@ -289,6 +291,15 @@ private:
   /// \brief Check whether the data dependence could prevent store-load
   /// forwarding.
   bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The dependence checker needs this in order to convert SCEVs of pointers
+  /// to more accurate expressions in the context of existing assumptions.
+  /// We also need this in case assumptions about SCEV expressions need to
+  /// be made in order to avoid unknown dependences. For example we might
+  /// assume a unit stride for a pointer in order to prove that a memory access
+  /// is strided and doesn't wrap.
+  SCEVUnionPredicate &Preds;
 };
 
 /// \brief Holds information about the memory runtime legality checks to verify
@@ -330,8 +341,13 @@ public:
   }
 
   /// Insert a pointer and calculate the start and end SCEVs.
+  /// \p We need Preds in order to compute the SCEV expression of the pointer
+  /// according to the assumptions that we've made during the analysis.
+  /// The method might also version the pointer stride according to \p Strides,
+  /// and change \p Preds.
   void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
-              unsigned ASId, const ValueToValueMap &Strides);
+              unsigned ASId, const ValueToValueMap &Strides,
+              SCEVUnionPredicate &Preds);
 
   /// \brief No run-time memory checking is necessary.
   bool empty() const { return Pointers.empty(); }
@@ -537,6 +553,15 @@ public:
     return StoreToLoopInvariantAddress;
   }
 
+  /// The SCEV predicate contains all the SCEV-related assumptions.
+  /// The is used to keep track of the minimal set of assumptions on SCEV
+  /// expressions that the analysis needs to make in order to return a
+  /// meaningful result. All SCEV expressions during the analysis should be
+  /// re-written (and therefore simplified) according to Preds.
+  /// A user of LoopAccessAnalysis will need to emit the runtime checks
+  /// associated with this predicate.
+  SCEVUnionPredicate Preds;
+
 private:
   /// \brief Analyze the loop.  Substitute symbolic strides using Strides.
   void analyzeLoop(const ValueToValueMap &Strides);
@@ -583,19 +608,26 @@ private:
 Value *stripIntegerCast(Value *V);
 
 ///\brief Return the SCEV corresponding to a pointer with the symbolic stride
-///replaced with constant one.
+/// replaced with constant one, assuming \p Preds is true.
+///
+/// If necessary this method will version the stride of the pointer according
+/// to \p PtrToStride and therefore add a new predicate to \p Preds.
 ///
 /// If \p OrigPtr is not null, use it to look up the stride value instead of \p
 /// Ptr.  \p PtrToStride provides the mapping between the pointer value and its
 /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
 const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
                                       const ValueToValueMap &PtrToStride,
-                                      Value *Ptr, Value *OrigPtr = nullptr);
+                                      SCEVUnionPredicate &Preds, Value *Ptr,
+                                      Value *OrigPtr = nullptr);
 
 /// \brief Check the stride of the pointer and ensure that it does not wrap in
-/// the address space.
+/// the address space, assuming \p Preds is true.
+///
+/// If necessary this method will version the stride of the pointer according
+/// to \p PtrToStride and therefore add a new predicate to \p Preds.
 int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
-                 const ValueToValueMap &StridesMap);
+                 const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds);
 
 /// \brief This analysis provides dependence information for the memory accesses
 /// of a loop.
index fa2a0555920ffa778aea5d1558e4715ed66307b8..fd380d9256861da4edad094dbd1540e475dabab4 100644 (file)
@@ -48,10 +48,15 @@ namespace llvm {
   class Loop;
   class LoopInfo;
   class Operator;
-  class SCEVUnknown;
-  class SCEVAddRecExpr;
   class SCEV;
-  template<> struct FoldingSetTrait<SCEV>;
+  class SCEVAddRecExpr;
+  class SCEVConstant;
+  class SCEVExpander;
+  class SCEVPredicate;
+  class SCEVUnknown;
+
+  template <> struct FoldingSetTrait<SCEV>;
+  template <> struct FoldingSetTrait<SCEVPredicate>;
 
   /// This class represents an analyzed expression in the program.  These are
   /// opaque objects that the client is not allowed to do much with directly.
@@ -164,6 +169,148 @@ namespace llvm {
     static bool classof(const SCEV *S);
   };
 
+  /// SCEVPredicate - This class represents an assumption made using SCEV
+  /// expressions which can be checked at run-time.
+  class SCEVPredicate : public FoldingSetNode {
+    friend struct FoldingSetTrait<SCEVPredicate>;
+
+    /// A reference to an Interned FoldingSetNodeID for this node.  The
+    /// ScalarEvolution's BumpPtrAllocator holds the data.
+    FoldingSetNodeIDRef FastID;
+
+  public:
+    enum SCEVPredicateKind { P_Union, P_Equal };
+
+  protected:
+    SCEVPredicateKind Kind;
+
+  public:
+    SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
+
+    virtual ~SCEVPredicate() {}
+
+    SCEVPredicateKind getKind() const { return Kind; }
+
+    /// \brief Returns the estimated complexity of this predicate.
+    /// This is roughly measured in the number of run-time checks required.
+    virtual unsigned getComplexity() { return 1; }
+
+    /// \brief Returns true if the predicate is always true. This means that no
+    /// assumptions were made and nothing needs to be checked at run-time.
+    virtual bool isAlwaysTrue() const = 0;
+
+    /// \brief Returns true if this predicate implies \p N.
+    virtual bool implies(const SCEVPredicate *N) const = 0;
+
+    /// \brief Prints a textual representation of this predicate with an
+    /// indentation of \p Depth.
+    virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
+
+    /// \brief Returns the SCEV to which this predicate applies, or nullptr
+    /// if this is a SCEVUnionPredicate.
+    virtual const SCEV *getExpr() const = 0;
+  };
+
+  inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
+    P.print(OS);
+    return OS;
+  }
+
+  // Specialize FoldingSetTrait for SCEVPredicate to avoid needing to compute
+  // temporary FoldingSetNodeID values.
+  template <>
+  struct FoldingSetTrait<SCEVPredicate>
+      : DefaultFoldingSetTrait<SCEVPredicate> {
+
+    static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
+      ID = X.FastID;
+    }
+
+    static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
+                       unsigned IDHash, FoldingSetNodeID &TempID) {
+      return ID == X.FastID;
+    }
+    static unsigned ComputeHash(const SCEVPredicate &X,
+                                FoldingSetNodeID &TempID) {
+      return X.FastID.ComputeHash();
+    }
+  };
+
+  /// SCEVEqualPredicate - This class represents an assumption that two SCEV
+  /// expressions are equal, and this can be checked at run-time. We assume
+  /// that the left hand side is a SCEVUnknown and the right hand side a
+  /// constant.
+  class SCEVEqualPredicate : public SCEVPredicate {
+    /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
+    /// constant.
+    const SCEVUnknown *LHS;
+    const SCEVConstant *RHS;
+
+  public:
+    SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
+                       const SCEVConstant *RHS);
+
+    /// Implementation of the SCEVPredicate interface
+    bool implies(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth = 0) const override;
+    bool isAlwaysTrue() const override;
+    const SCEV *getExpr() const;
+
+    /// \brief Returns the left hand side of the equality.
+    const SCEVUnknown *getLHS() const { return LHS; }
+
+    /// \brief Returns the right hand side of the equality.
+    const SCEVConstant *getRHS() const { return RHS; }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getKind() == P_Equal;
+    }
+  };
+
+  /// SCEVUnionPredicate - This class represents a composition of other
+  /// SCEV predicates, and is the class that most clients will interact with.
+  /// This is equivalent to a logical "AND" of all the predicates in the union.
+  class SCEVUnionPredicate : public SCEVPredicate {
+  private:
+    typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
+        PredicateMap;
+
+    /// Vector with references to all predicates in this union.
+    SmallVector<const SCEVPredicate *, 16> Preds;
+    /// Maps SCEVs to predicates for quick look-ups.
+    PredicateMap SCEVToPreds;
+
+  public:
+    SCEVUnionPredicate();
+
+    const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
+      return Preds;
+    }
+
+    /// \brief Adds a predicate to this union.
+    void add(const SCEVPredicate *N);
+
+    /// \brief Returns a reference to a vector containing all predicates
+    /// which apply to \p Expr.
+    ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
+
+    /// Implementation of the SCEVPredicate interface
+    bool isAlwaysTrue() const override;
+    bool implies(const SCEVPredicate *N) const override;
+    void print(raw_ostream &OS, unsigned Depth) const;
+    const SCEV *getExpr() const override;
+
+    /// \brief We estimate the complexity of a union predicate as the size
+    /// number of predicates in the union.
+    unsigned getComplexity() override { return Preds.size(); }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVPredicate *P) {
+      return P->getKind() == P_Union;
+    }
+  };
+
   /// The main scalar evolution driver. Because client code (intentionally)
   /// can't do much with the SCEV objects directly, they must ask this class
   /// for services.
@@ -1108,6 +1255,12 @@ namespace llvm {
       return F.getParent()->getDataLayout();
     }
 
+    const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
+                                           const SCEVConstant *RHS);
+
+    /// Re-writes the SCEV according to the Predicates in \p Preds.
+    const SCEV *rewriteUsingPredicate(const SCEV *Scev, SCEVUnionPredicate &A);
+
   private:
     /// Compute the backedge taken count knowing the interval difference, the
     /// stride and presence of the equality in the comparison.
@@ -1128,6 +1281,7 @@ namespace llvm {
 
   private:
     FoldingSet<SCEV> UniqueSCEVs;
+    FoldingSet<SCEVPredicate> UniquePreds;
     BumpPtrAllocator SCEVAllocator;
 
     /// The head of a linked list of all SCEVUnknown values that have been
index 6bddb8a8132184f61a5ce5825856a88d67c3c89b..b9939168a99d33fd17dc943ac735761e2bbcab4e 100644 (file)
@@ -151,6 +151,22 @@ namespace llvm {
     /// block.
     Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I);
 
+    /// \brief Generates a code sequence that evaluates this predicate.
+    /// The inserted instructions will be at position \p Loc.
+    /// The result will be of type i1 and will have a value of 0 when the
+    /// predicate is false and 1 otherwise.
+    Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc);
+
+    /// \brief A specialized variant of expandCodeForPredicate, handling the
+    /// case when we are expanding code for a SCEVEqualPredicate.
+    Value *expandEqualPredicate(const SCEVEqualPredicate *Pred,
+                                Instruction *Loc);
+
+    /// \brief A specialized variant of expandCodeForPredicate, handling the
+    /// case when we are expanding code for a SCEVUnionPredicate.
+    Value *expandUnionPredicate(const SCEVUnionPredicate *Pred,
+                                Instruction *Loc);
+
     /// \brief Set the current IV increment loop and position.
     void setIVIncInsertPos(const Loop *L, Instruction *Pos) {
       assert(!CanonicalMode &&
index 770e0e2ec060b89ec0565c77f2c8a7deabe380e0..25025db0527ad0dc1deca5261a9588c2c3723007 100644 (file)
@@ -89,8 +89,8 @@ Value *llvm::stripIntegerCast(Value *V) {
 
 const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
                                             const ValueToValueMap &PtrToStride,
+                                            SCEVUnionPredicate &Preds,
                                             Value *Ptr, Value *OrigPtr) {
-
   const SCEV *OrigSCEV = SE->getSCEV(Ptr);
 
   // If there is an entry in the map return the SCEV of the pointer with the
@@ -108,22 +108,28 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
     ValueToValueMap RewriteMap;
     RewriteMap[StrideVal] = One;
 
-    const SCEV *ByOne =
-        SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true);
+    const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
+    const auto *CT =
+        static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
+
+    Preds.add(SE->getEqualPredicate(U, CT));
+
+    const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds);
     DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
                  << "\n");
     return ByOne;
   }
 
   // Otherwise, just return the SCEV of the original pointer.
-  return SE->getSCEV(Ptr);
+  return OrigSCEV;
 }
 
 void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
                                     unsigned DepSetId, unsigned ASId,
-                                    const ValueToValueMap &Strides) {
+                                    const ValueToValueMap &Strides,
+                                    SCEVUnionPredicate &Preds) {
   // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
   assert(AR && "Invalid addrec expression");
   const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
@@ -417,9 +423,9 @@ public:
   typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
 
   AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
-                 MemoryDepChecker::DepCandidates &DA)
-      : DL(Dl), AST(*AA), LI(LI), DepCands(DA),
-        IsRTCheckAnalysisNeeded(false) {}
+                 MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds)
+      : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
+        Preds(Preds) {}
 
   /// \brief Register a load  and whether it is only read from.
   void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
@@ -504,14 +510,18 @@ private:
   /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared
   /// while this remains set if we have potentially dependent accesses.
   bool IsRTCheckAnalysisNeeded;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  SCEVUnionPredicate &Preds;
 };
 
 } // end anonymous namespace
 
 /// \brief Check whether a pointer can participate in a runtime bounds check.
 static bool hasComputableBounds(ScalarEvolution *SE,
-                                const ValueToValueMap &Strides, Value *Ptr) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+                                const ValueToValueMap &Strides, Value *Ptr,
+                                Loop *L, SCEVUnionPredicate &Preds) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR)
     return false;
@@ -554,11 +564,11 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
       else
         ++NumReadPtrChecks;
 
-      if (hasComputableBounds(SE, StridesMap, Ptr) &&
+      if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) &&
           // When we run after a failing dependency check we have to make sure
           // we don't have wrapping pointers.
           (!ShouldCheckStride ||
-           isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) {
+           isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) {
         // The id of the dependence set.
         unsigned DepId;
 
@@ -572,7 +582,7 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
           // Each access has its own dependence set.
           DepId = RunningDepId++;
 
-        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
+        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds);
 
         DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
       } else {
@@ -803,7 +813,8 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
 
 /// \brief Check whether the access through \p Ptr has a constant stride.
 int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
-                       const ValueToValueMap &StridesMap) {
+                       const ValueToValueMap &StridesMap,
+                       SCEVUnionPredicate &Preds) {
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
 
@@ -815,7 +826,7 @@ int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
     return 0;
   }
 
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr);
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR) {
@@ -1026,11 +1037,11 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       BPtr->getType()->getPointerAddressSpace())
     return Dependence::Unknown;
 
-  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
-  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
+  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr);
+  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr);
 
-  int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides);
-  int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides);
+  int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds);
+  int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds);
 
   const SCEV *Src = AScev;
   const SCEV *Sink = BScev;
@@ -1429,7 +1440,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) {
 
   MemoryDepChecker::DepCandidates DependentAccesses;
   AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
-                          AA, LI, DependentAccesses);
+                          AA, LI, DependentAccesses, Preds);
 
   // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
   // multiple times on the same object. If the ptr is accessed twice, once
@@ -1480,7 +1491,8 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) {
     // read a few words, modify, and write a few words, and some of the
     // words may be written to the same address.
     bool IsReadOnlyPtr = false;
-    if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) {
+    if (Seen.insert(Ptr).second ||
+        !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) {
       ++NumReads;
       IsReadOnlyPtr = true;
     }
@@ -1730,7 +1742,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
                                const TargetLibraryInfo *TLI, AliasAnalysis *AA,
                                DominatorTree *DT, LoopInfo *LI,
                                const ValueToValueMap &Strides)
-    : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL),
+    : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL),
       TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
       MaxSafeDepDistBytes(-1U), CanVecMem(false),
       StoreToLoopInvariantAddress(false) {
@@ -1765,6 +1777,9 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {
   OS.indent(Depth) << "Store to invariant address was "
                    << (StoreToLoopInvariantAddress ? "" : "not ")
                    << "found in loop.\n";
+
+  OS.indent(Depth) << "SCEV assumptions:\n";
+  Preds.print(OS, Depth);
 }
 
 const LoopAccessInfo &
@@ -1778,8 +1793,8 @@ LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides) {
 
   if (!LAI) {
     const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
-    LAI = llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI,
-                                            Strides);
+    LAI =
+        llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI, Strides);
 #ifndef NDEBUG
     LAI->NumSymbolicStrides = Strides.size();
 #endif
index 66c5290b8a0aca18df202ea6f20f4de40d8335c1..53e63d70a3e887c55d2ad1b3cab0fe886e8ad934 100644 (file)
@@ -9093,6 +9093,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UnsignedRanges(std::move(Arg.UnsignedRanges)),
       SignedRanges(std::move(Arg.SignedRanges)),
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+      UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -9596,3 +9597,134 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+                                   const SCEVConstant *RHS) {
+  FoldingSetNodeID ID;
+  // Unique this node based on the arguments
+  ID.AddInteger(SCEVPredicate::P_Equal);
+  ID.AddPointer(LHS);
+  ID.AddPointer(RHS);
+  void *IP = nullptr;
+  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+    return S;
+  SCEVEqualPredicate *Eq = new (SCEVAllocator)
+      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+  UniquePreds.InsertNode(Eq, IP);
+  return Eq;
+}
+
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A) {
+    SCEVPredicateRewriter Rewriter(SE, A);
+    return Rewriter.visit(Scev);
+  }
+
+  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+      : SCEVRewriteVisitor(SE), P(P) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    auto ExprPreds = P.getPredicatesForExpr(Expr);
+    for (auto *Pred : ExprPreds)
+      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+        if (IPred->getLHS() == Expr)
+          return IPred->getRHS();
+
+    return Expr;
+  }
+
+private:
+  SCEVUnionPredicate &P;
+};
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+                                                   SCEVUnionPredicate &Preds) {
+  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+                             SCEVPredicateKind Kind)
+    : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+                                       const SCEVUnknown *LHS,
+                                       const SCEVConstant *RHS)
+    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+
+  if (!Op)
+    return false;
+
+  return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+  return std::all_of(Preds.begin(), Preds.end(),
+                     [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+  auto I = SCEVToPreds.find(Expr);
+  if (I == SCEVToPreds.end())
+    return ArrayRef<const SCEVPredicate *>();
+  return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+    return std::all_of(
+        Set->Preds.begin(), Set->Preds.end(),
+        [this](const SCEVPredicate *I) { return this->implies(I); });
+
+  auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+  if (ScevPredsIt == SCEVToPreds.end())
+    return false;
+  auto &SCEVPreds = ScevPredsIt->second;
+
+  return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
+                     [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  for (auto Pred : Preds)
+    Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+    for (auto Pred : Set->Preds)
+      add(Pred);
+    return;
+  }
+
+  if (implies(N))
+    return;
+
+  const SCEV *Key = N->getExpr();
+  assert(Key && "Only SCEVUnionPredicate doesn't have an "
+                " associated expression!");
+
+  SCEVToPreds[Key].push_back(N);
+  Preds.push_back(N);
+}
index 81316849847a0ca09a96366707d2ffd087db12de..2c2e5828003aea6dd76c7ce27108b0df39d16a1d 100644 (file)
@@ -1944,6 +1944,43 @@ bool SCEVExpander::isHighCostExpansionHelper(
   return false;
 }
 
+Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
+                                            Instruction *IP) {
+  assert(IP);
+  switch (Pred->getKind()) {
+  case SCEVPredicate::P_Union:
+    return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
+  case SCEVPredicate::P_Equal:
+    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  }
+  llvm_unreachable("Unknown SCEV predicate type");
+}
+
+Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
+                                          Instruction *IP) {
+  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
+  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+
+  Builder.SetInsertPoint(IP);
+  auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
+  return I;
+}
+
+Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
+                                          Instruction *IP) {
+  auto *BoolType = IntegerType::get(IP->getContext(), 1);
+  Value *Check = ConstantInt::getNullValue(BoolType);
+
+  // Loop over all checks in this set.
+  for (auto Pred : Union->getPredicates()) {
+    auto *NextCheck = expandCodeForPredicate(Pred, IP);
+    Builder.SetInsertPoint(IP);
+    Check = Builder.CreateOr(Check, NextCheck);
+  }
+
+  return Check;
+}
+
 namespace {
 // Search for a SCEV subexpression that is not safe to expand.  Any expression
 // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
index ae5ec8cb88a84411c7c22f81b98464daf431860c..b0cfbcc6777ec354a16c840bc6d58ade26f4234d 100644 (file)
@@ -222,6 +222,15 @@ static cl::opt<unsigned> PragmaVectorizeMemoryCheckThreshold(
     cl::desc("The maximum allowed number of runtime memory checks with a "
              "vectorize(enable) pragma."));
 
+static cl::opt<unsigned> VectorizeSCEVCheckThreshold(
+    "vectorize-scev-check-threshold", cl::init(16), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed."));
+
+static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold(
+    "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed with a "
+             "vectorize(enable) pragma"));
+
 namespace {
 
 // Forward declarations.
@@ -273,12 +282,12 @@ public:
   InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
                       DominatorTree *DT, const TargetLibraryInfo *TLI,
                       const TargetTransformInfo *TTI, unsigned VecWidth,
-                      unsigned UnrollFactor)
+                      unsigned UnrollFactor, SCEVUnionPredicate &Preds)
       : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
         VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
         Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
         TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
-        AddedSafetyChecks(false) {}
+        AddedSafetyChecks(false), Preds(Preds) {}
 
   // Perform the actual loop widening (vectorization).
   // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
@@ -315,12 +324,6 @@ protected:
   typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
                    VectorParts> EdgeMaskCache;
 
-  /// \brief Add checks for strides that were assumed to be 1.
-  ///
-  /// Returns the last check instruction and the first check instruction in the
-  /// pair as (first, last).
-  std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
-
   /// Create an empty loop, based on the loop ranges of the old loop.
   void createEmptyLoop();
   /// Create a new induction variable inside L.
@@ -404,11 +407,12 @@ protected:
   void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass);
   /// Emit a bypass check to see if the vector trip count is nonzero.
   void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass);
-  /// Emit bypass checks to check if strides we've assumed to be one really are.
-  void emitStrideChecks(Loop *L, BasicBlock *Bypass);
+  /// Emit a bypass check to see if all of the SCEV assumptions we've
+  /// had to make are correct.
+  void emitSCEVChecks(Loop *L, BasicBlock *Bypass);
   /// Emit bypass checks to check any memory assumptions we may have made.
   void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass);
-  
+
   /// This is a helper class that holds the vectorizer state. It maps scalar
   /// instructions to vector instructions. When the code is 'unrolled' then
   /// then a single scalar value is mapped to multiple vector parts. The parts
@@ -516,14 +520,23 @@ protected:
 
   // Record whether runtime check is added.
   bool AddedSafetyChecks;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The predicate is used to simplify existing expressions in the
+  /// context of existing SCEV assumptions. Since legality checking is
+  /// not done here, we don't need to use this predicate to record
+  /// further assumptions.
+  SCEVUnionPredicate &Preds;
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
 public:
   InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
                     DominatorTree *DT, const TargetLibraryInfo *TLI,
-                    const TargetTransformInfo *TTI, unsigned UnrollFactor)
-      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
+                    const TargetTransformInfo *TTI, unsigned UnrollFactor,
+                    SCEVUnionPredicate &Preds)
+      : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
+                            Preds) {}
 
 private:
   void scalarizeInstruction(Instruction *Instr,
@@ -744,8 +757,9 @@ private:
 /// between the member and the group in a map.
 class InterleavedAccessInfo {
 public:
-  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT)
-      : SE(SE), TheLoop(L), DT(DT) {}
+  InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
+                        SCEVUnionPredicate &Preds)
+      : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {}
 
   ~InterleavedAccessInfo() {
     SmallSet<InterleaveGroup *, 4> DelSet;
@@ -779,6 +793,13 @@ private:
   Loop *TheLoop;
   DominatorTree *DT;
 
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The predicate is used to simplify SCEV expressions in the
+  /// context of existing SCEV assumptions. The interleaved access
+  /// analysis can also add new predicates (for example by versioning
+  /// strides of pointers).
+  SCEVUnionPredicate &Preds;
+
   /// Holds the relationships between the members and the interleave group.
   DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
 
@@ -1141,11 +1162,13 @@ public:
                             Function *F, const TargetTransformInfo *TTI,
                             LoopAccessAnalysis *LAA,
                             LoopVectorizationRequirements *R,
-                            const LoopVectorizeHints *H)
+                            const LoopVectorizeHints *H,
+                            SCEVUnionPredicate &Preds)
       : NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
-        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT),
-        Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
-        Requirements(R), Hints(H) {}
+        TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
+        InterleaveInfo(SE, L, DT, Preds), Induction(nullptr),
+        WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H),
+        Preds(Preds) {}
 
   /// ReductionList contains the reduction descriptors for all
   /// of the reductions that were found in the loop.
@@ -1344,7 +1367,14 @@ private:
 
   /// While vectorizing these instructions we have to generate a
   /// call to the appropriate masked intrinsic
-  SmallPtrSet<const Instruction*, 8> MaskedOp;
+  SmallPtrSet<const Instruction *, 8> MaskedOp;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  /// The predicate is used to simplify SCEV expressions in the
+  /// context of existing SCEV assumptions. The analysis will also
+  /// add a minimal set of new predicates if this is required to
+  /// enable vectorization/unrolling.
+  SCEVUnionPredicate &Preds;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -1360,9 +1390,10 @@ public:
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
                              const TargetLibraryInfo *TLI, DemandedBits *DB,
-                             AssumptionCache *AC,
-                             const Function *F, const LoopVectorizeHints *Hints,
-                             SmallPtrSetImpl<const Value *> &ValuesToIgnore)
+                             AssumptionCache *AC, const Function *F,
+                             const LoopVectorizeHints *Hints,
+                             SmallPtrSetImpl<const Value *> &ValuesToIgnore,
+                             SCEVUnionPredicate &Preds)
       : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
         TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
@@ -1690,10 +1721,12 @@ struct LoopVectorize : public FunctionPass {
       }
     }
 
+    SCEVUnionPredicate Preds;
+
     // Check if it is legal to vectorize the loop.
     LoopVectorizationRequirements Requirements;
     LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
-                                  &Requirements, &Hints);
+                                  &Requirements, &Hints, Preds);
     if (!LVL.canVectorize()) {
       DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
       emitMissedWarning(F, L, Hints);
@@ -1712,7 +1745,7 @@ struct LoopVectorize : public FunctionPass {
 
     // Use the cost model.
     LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
-                                  ValuesToIgnore);
+                                  ValuesToIgnore, Preds);
 
     // Check the function attributes to find out if this function should be
     // optimized for size.
@@ -1823,7 +1856,7 @@ struct LoopVectorize : public FunctionPass {
       assert(IC > 1 && "interleave count should not be 1 or 0");
       // If we decided that it is not legal to vectorize the loop then
       // interleave it.
-      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
+      InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds);
       Unroller.vectorize(&LVL, CM.MinBWs);
 
       emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
@@ -1831,7 +1864,7 @@ struct LoopVectorize : public FunctionPass {
                                  Twine(IC) + ")");
     } else {
       // If we decided that it is *legal* to vectorize the loop then do it.
-      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
+      InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds);
       LB.vectorize(&LVL, CM.MinBWs);
       ++LoopsVectorized;
 
@@ -1992,7 +2025,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
     //  %idxprom = zext i32 %mul to i64  << Safe cast.
     //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
     //
-    Last = replaceSymbolicStrideSCEV(SE, Strides,
+    Last = replaceSymbolicStrideSCEV(SE, Strides, Preds,
                                      Gep->getOperand(InductionOperand), Gep);
     if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
       Last =
@@ -2551,56 +2584,8 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, bool IfPredic
   }
 }
 
-static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
-                                 Instruction *Loc) {
-  if (FirstInst)
-    return FirstInst;
-  if (Instruction *I = dyn_cast<Instruction>(V))
-    return I->getParent() == Loc->getParent() ? I : nullptr;
-  return nullptr;
-}
-
-std::pair<Instruction *, Instruction *>
-InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
-  Instruction *tnullptr = nullptr;
-  if (!Legal->mustCheckStrides())
-    return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
-
-  IRBuilder<> ChkBuilder(Loc);
-
-  // Emit checks.
-  Value *Check = nullptr;
-  Instruction *FirstInst = nullptr;
-  for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
-                                         SE = Legal->strides_end();
-       SI != SE; ++SI) {
-    Value *Ptr = stripIntegerCast(*SI);
-    Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
-                                       "stride.chk");
-    // Store the first instruction we create.
-    FirstInst = getFirstInst(FirstInst, C, Loc);
-    if (Check)
-      Check = ChkBuilder.CreateOr(Check, C);
-    else
-      Check = C;
-  }
-
-  // We have to do this trickery because the IRBuilder might fold the check to a
-  // constant expression in which case there is no Instruction anchored in a
-  // the block.
-  LLVMContext &Ctx = Loc->getContext();
-  Instruction *TheCheck =
-      BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
-  ChkBuilder.Insert(TheCheck, "stride.not.one");
-  FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
-
-  return std::make_pair(FirstInst, TheCheck);
-}
-
-PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L,
-                                                      Value *Start,
-                                                      Value *End,
-                                                      Value *Step,
+PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start,
+                                                      Value *End, Value *Step,
                                                       Instruction *DL) {
   BasicBlock *Header = L->getHeader();
   BasicBlock *Latch = L->getLoopLatch();
@@ -2735,26 +2720,26 @@ void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L,
   LoopBypassBlocks.push_back(BB);
 }
 
-void InnerLoopVectorizer::emitStrideChecks(Loop *L,
-                                           BasicBlock *Bypass) {
+void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
   BasicBlock *BB = L->getLoopPreheader();
-  
-  // Generate the code to check that the strides we assumed to be one are really
-  // one. We want the new basic block to start at the first instruction in a
+
+  // Generate the code to check that the SCEV assumptions that we made.
+  // We want the new basic block to start at the first instruction in a
   // sequence of instructions that form a check.
-  Instruction *StrideCheck;
-  Instruction *FirstCheckInst;
-  std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator());
-  if (!StrideCheck)
-    return;
+  SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check");
+  Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator());
+
+  if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
+    if (C->isZero())
+      return;
 
   // Create a new block containing the stride check.
-  BB->setName("vector.stridecheck");
+  BB->setName("vector.scevcheck");
   auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
   if (L->getParentLoop())
     L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
   ReplaceInstWithInst(BB->getTerminator(),
-                      BranchInst::Create(Bypass, NewBB, StrideCheck));
+                      BranchInst::Create(Bypass, NewBB, SCEVCheck));
   LoopBypassBlocks.push_back(BB);
   AddedSafetyChecks = true;
 }
@@ -2874,10 +2859,10 @@ void InnerLoopVectorizer::createEmptyLoop() {
   // Now, compare the new count to zero. If it is zero skip the vector loop and
   // jump to the scalar loop.
   emitVectorLoopEnteredCheck(Lp, ScalarPH);
-  // Generate the code to check that the strides we assumed to be one are really
-  // one. We want the new basic block to start at the first instruction in a
-  // sequence of instructions that form a check.
-  emitStrideChecks(Lp, ScalarPH);
+  // Generate the code to check any assumptions that we've made for SCEV
+  // expressions.
+  emitSCEVChecks(Lp, ScalarPH);
+
   // Generate the code that checks in runtime if arrays overlap. We put the
   // checks into a separate block to make the more common case of few elements
   // faster.
@@ -4130,7 +4115,19 @@ bool LoopVectorizationLegality::canVectorize() {
 
   // Analyze interleaved memory accesses.
   if (UseInterleaved)
-     InterleaveInfo.analyzeInterleaving(Strides);
+    InterleaveInfo.analyzeInterleaving(Strides);
+
+  unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;
+  if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
+    SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
+
+  if (Preds.getComplexity() > SCEVThreshold) {
+    emitAnalysis(VectorizationReport()
+                 << "Too many SCEV assumptions need to be made and checked "
+                 << "at runtime");
+    DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n");
+    return false;
+  }
 
   // Okay! We can vectorize. At this point we don't have any other mem analysis
   // which may limit our maximum vectorization factor, so just return true with
@@ -4436,6 +4433,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   }
 
   Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
+  Preds.add(&LAI->Preds);
 
   return true;
 }
@@ -4550,7 +4548,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses(
     StoreInst *SI = dyn_cast<StoreInst>(I);
 
     Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
-    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides);
+    int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds);
 
     // The factor of the corresponding interleave group.
     unsigned Factor = std::abs(Stride);
@@ -4559,7 +4557,7 @@ void InterleavedAccessInfo::collectConstStridedAccesses(
     if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
       continue;
 
-    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+    const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
     PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
     unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());