class ScalarEvolution;
class Loop;
class SCEV;
+class SCEVUnionPredicate;
/// Optimization analysis message produced during vectorization. Messages inform
/// the user why vectorization did not occur.
const SmallVectorImpl<Instruction *> &Instrs) const;
};
- MemoryDepChecker(ScalarEvolution *Se, const Loop *L)
+ MemoryDepChecker(ScalarEvolution *Se, const Loop *L,
+ SCEVUnionPredicate &Preds)
: SE(Se), InnermostLoop(L), AccessIdx(0),
ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true),
- RecordInterestingDependences(true) {}
+ RecordInterestingDependences(true), Preds(Preds) {}
/// \brief Register the location (instructions are given increasing numbers)
/// of a write access.
/// \brief Check whether the data dependence could prevent store-load
/// forwarding.
bool couldPreventStoreLoadForward(unsigned Distance, unsigned TypeByteSize);
+
+ /// The SCEV predicate containing all the SCEV-related assumptions.
+ /// The dependence checker needs this in order to convert SCEVs of pointers
+ /// to more accurate expressions in the context of existing assumptions.
+ /// We also need this in case assumptions about SCEV expressions need to
+ /// be made in order to avoid unknown dependences. For example we might
+ /// assume a unit stride for a pointer in order to prove that a memory access
+ /// is strided and doesn't wrap.
+ SCEVUnionPredicate &Preds;
};
/// \brief Holds information about the memory runtime legality checks to verify
}
/// Insert a pointer and calculate the start and end SCEVs.
+ /// We need \p Preds in order to compute the SCEV expression of the pointer
+ /// according to the assumptions that we've made during the analysis.
+ /// The method might also version the pointer stride according to \p Strides,
+ /// and change \p Preds.
void insert(Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
- unsigned ASId, const ValueToValueMap &Strides);
+ unsigned ASId, const ValueToValueMap &Strides,
+ SCEVUnionPredicate &Preds);
/// \brief No run-time memory checking is necessary.
bool empty() const { return Pointers.empty(); }
return StoreToLoopInvariantAddress;
}
+ /// The SCEV predicate contains all the SCEV-related assumptions.
+ /// This is used to keep track of the minimal set of assumptions on SCEV
+ /// expressions that the analysis needs to make in order to return a
+ /// meaningful result. All SCEV expressions during the analysis should be
+ /// re-written (and therefore simplified) according to Preds.
+ /// A user of LoopAccessAnalysis will need to emit the runtime checks
+ /// associated with this predicate.
+ SCEVUnionPredicate Preds;
+
private:
/// \brief Analyze the loop. Substitute symbolic strides using Strides.
void analyzeLoop(const ValueToValueMap &Strides);
Value *stripIntegerCast(Value *V);
///\brief Return the SCEV corresponding to a pointer with the symbolic stride
-///replaced with constant one.
+/// replaced with constant one, assuming \p Preds is true.
+///
+/// If necessary this method will version the stride of the pointer according
+/// to \p PtrToStride and therefore add a new predicate to \p Preds.
///
/// If \p OrigPtr is not null, use it to look up the stride value instead of \p
/// Ptr. \p PtrToStride provides the mapping between the pointer value and its
/// stride as collected by LoopVectorizationLegality::collectStridedAccess.
const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
const ValueToValueMap &PtrToStride,
- Value *Ptr, Value *OrigPtr = nullptr);
+ SCEVUnionPredicate &Preds, Value *Ptr,
+ Value *OrigPtr = nullptr);
/// \brief Check the stride of the pointer and ensure that it does not wrap in
-/// the address space.
+/// the address space, assuming \p Preds is true.
+///
+/// If necessary this method will version the stride of the pointer according
+/// to \p PtrToStride and therefore add a new predicate to \p Preds.
int isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
- const ValueToValueMap &StridesMap);
+ const ValueToValueMap &StridesMap, SCEVUnionPredicate &Preds);
/// \brief This analysis provides dependence information for the memory accesses
/// of a loop.
class Loop;
class LoopInfo;
class Operator;
- class SCEVUnknown;
- class SCEVAddRecExpr;
class SCEV;
- template<> struct FoldingSetTrait<SCEV>;
+ class SCEVAddRecExpr;
+ class SCEVConstant;
+ class SCEVExpander;
+ class SCEVPredicate;
+ class SCEVUnknown;
+
+ template <> struct FoldingSetTrait<SCEV>;
+ template <> struct FoldingSetTrait<SCEVPredicate>;
/// This class represents an analyzed expression in the program. These are
/// opaque objects that the client is not allowed to do much with directly.
static bool classof(const SCEV *S);
};
+ /// SCEVPredicate - Abstract base class for an assumption, expressed over
+ /// SCEV expressions, that can be verified by a run-time check.
+ class SCEVPredicate : public FoldingSetNode {
+   friend struct FoldingSetTrait<SCEVPredicate>;
+
+   /// Interned FoldingSetNodeID of this node. FoldingSetTrait<SCEVPredicate>
+   /// answers profiling, equality and hashing queries straight from it. For
+   /// predicates created through ScalarEvolution the ID data is interned in
+   /// ScalarEvolution's BumpPtrAllocator.
+   FoldingSetNodeIDRef FastID;
+
+ public:
+   /// Discriminator identifying the concrete predicate class; consumed by the
+   /// classof() implementations of the subclasses.
+   enum SCEVPredicateKind { P_Union, P_Equal };
+
+ protected:
+   /// Kind of the most-derived predicate object.
+   SCEVPredicateKind Kind;
+
+ public:
+   /// Stores the interned \p ID and the \p Kind of the concrete predicate.
+   SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind);
+
+   virtual ~SCEVPredicate() = default;
+
+   /// \brief Returns the kind tag given at construction time.
+   SCEVPredicateKind getKind() const { return Kind; }
+
+   /// \brief Estimated cost of this predicate, roughly the number of
+   /// run-time checks it requires. A single predicate counts as one check.
+   virtual unsigned getComplexity() { return 1; }
+
+   /// \brief Returns true when the predicate holds unconditionally, i.e. no
+   /// assumption was made and nothing has to be checked at run-time.
+   virtual bool isAlwaysTrue() const = 0;
+
+   /// \brief Returns true if this predicate implies \p N.
+   virtual bool implies(const SCEVPredicate *N) const = 0;
+
+   /// \brief Writes a textual form of this predicate to \p OS, indented by
+   /// \p Depth.
+   virtual void print(raw_ostream &OS, unsigned Depth = 0) const = 0;
+
+   /// \brief Returns the SCEV this predicate is about, or nullptr for a
+   /// SCEVUnionPredicate (which has no single associated expression).
+   virtual const SCEV *getExpr() const = 0;
+ };
+
+ /// Streams the textual form of \p P to \p OS with no extra indentation and
+ /// returns \p OS so that insertions can be chained.
+ inline raw_ostream &operator<<(raw_ostream &OS, const SCEVPredicate &P) {
+   P.print(OS, /*Depth=*/0);
+   return OS;
+ }
+
+ // FoldingSetTrait specialization for SCEVPredicate. Every query is answered
+ // from the node's pre-interned FastID, which avoids computing temporary
+ // FoldingSetNodeID values.
+ template <>
+ struct FoldingSetTrait<SCEVPredicate>
+     : DefaultFoldingSetTrait<SCEVPredicate> {
+   /// Copies the interned ID of \p X into \p ID.
+   static void Profile(const SCEVPredicate &X, FoldingSetNodeID &ID) {
+     ID = X.FastID;
+   }
+
+   /// Compares \p ID against the interned ID of \p X; the hash and the
+   /// scratch ID are not needed.
+   static bool Equals(const SCEVPredicate &X, const FoldingSetNodeID &ID,
+                      unsigned /*IDHash*/, FoldingSetNodeID & /*TempID*/) {
+     return ID == X.FastID;
+   }
+
+   /// Hashes the interned ID of \p X; the scratch ID is not needed.
+   static unsigned ComputeHash(const SCEVPredicate &X,
+                               FoldingSetNodeID & /*TempID*/) {
+     return X.FastID.ComputeHash();
+   }
+ };
+
+ /// SCEVEqualPredicate - This class represents an assumption that two SCEV
+ /// expressions are equal, and this can be checked at run-time. We assume
+ /// that the left hand side is a SCEVUnknown and the right hand side a
+ /// constant.
+ class SCEVEqualPredicate : public SCEVPredicate {
+   /// We assume that LHS == RHS, where LHS is a SCEVUnknown and RHS a
+   /// constant.
+   const SCEVUnknown *LHS;
+   const SCEVConstant *RHS;
+
+ public:
+   /// Creates the predicate \p LHS == \p RHS identified by the interned
+   /// folding-set ID \p ID.
+   SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEVUnknown *LHS,
+                      const SCEVConstant *RHS);
+
+   /// Implementation of the SCEVPredicate interface
+
+   /// \brief Returns true only if \p N is an equality predicate with the
+   /// same left and right hand sides as this one.
+   bool implies(const SCEVPredicate *N) const override;
+
+   /// \brief Prints "Equal predicate: LHS == RHS" indented by \p Depth.
+   void print(raw_ostream &OS, unsigned Depth = 0) const override;
+
+   /// \brief Always returns false.
+   bool isAlwaysTrue() const override;
+
+   /// \brief Returns the left hand side of the equality. Marked override so
+   /// that a signature drift from SCEVPredicate::getExpr is diagnosed at
+   /// compile time, like the other interface methods.
+   const SCEV *getExpr() const override;
+
+   /// \brief Returns the left hand side of the equality.
+   const SCEVUnknown *getLHS() const { return LHS; }
+
+   /// \brief Returns the right hand side of the equality.
+   const SCEVConstant *getRHS() const { return RHS; }
+
+   /// Methods for support type inquiry through isa, cast, and dyn_cast:
+   static inline bool classof(const SCEVPredicate *P) {
+     return P->getKind() == P_Equal;
+   }
+ };
+
+ /// SCEVUnionPredicate - This class represents a composition of other
+ /// SCEV predicates, and is the class that most clients will interact with.
+ /// This is equivalent to a logical "AND" of all the predicates in the union.
+ class SCEVUnionPredicate : public SCEVPredicate {
+ private:
+   typedef DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>
+       PredicateMap;
+
+   /// Vector with references to all predicates in this union.
+   SmallVector<const SCEVPredicate *, 16> Preds;
+   /// Maps SCEVs to predicates for quick look-ups.
+   PredicateMap SCEVToPreds;
+
+ public:
+   /// Creates an empty union.
+   SCEVUnionPredicate();
+
+   /// \brief Returns all predicates contained in this union.
+   const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
+     return Preds;
+   }
+
+   /// \brief Adds a predicate to this union. A union passed as \p N is
+   /// flattened into its members, and a predicate that is already implied by
+   /// this union is not added again.
+   void add(const SCEVPredicate *N);
+
+   /// \brief Returns a reference to a vector containing all predicates
+   /// which apply to \p Expr. The result is empty when there are none.
+   ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
+
+   /// Implementation of the SCEVPredicate interface
+
+   /// \brief Returns true if every predicate in the union is always true
+   /// (in particular, when the union is empty).
+   bool isAlwaysTrue() const override;
+
+   /// \brief Returns true if this union implies \p N. A union \p N is implied
+   /// when each of its members is implied.
+   bool implies(const SCEVPredicate *N) const override;
+
+   /// \brief Prints every member predicate, each indented by \p Depth. Marked
+   /// override so that a signature drift from SCEVPredicate::print is
+   /// diagnosed at compile time, like the other interface methods.
+   void print(raw_ostream &OS, unsigned Depth) const override;
+
+   /// \brief Returns nullptr: a union has no single associated expression.
+   const SCEV *getExpr() const override;
+
+   /// \brief We estimate the complexity of a union predicate as the number of
+   /// predicates in the union.
+   unsigned getComplexity() override { return Preds.size(); }
+
+   /// Methods for support type inquiry through isa, cast, and dyn_cast:
+   static inline bool classof(const SCEVPredicate *P) {
+     return P->getKind() == P_Union;
+   }
+ };
+
/// The main scalar evolution driver. Because client code (intentionally)
/// can't do much with the SCEV objects directly, they must ask this class
/// for services.
return F.getParent()->getDataLayout();
}
+ /// Returns the uniqued equality predicate stating \p LHS == \p RHS. The
+ /// predicate is looked up in this ScalarEvolution's set of predicates and is
+ /// only created (from the SCEV allocator) when no matching one exists yet.
+ const SCEVPredicate *getEqualPredicate(const SCEVUnknown *LHS,
+ const SCEVConstant *RHS);
+
+ /// Re-writes the SCEV according to the Predicates in \p Preds.
+ ///
+ /// Every SCEVUnknown inside \p Scev that is the left hand side of an
+ /// equality predicate in \p Preds is replaced by that predicate's constant
+ /// right hand side. Returns the re-written expression.
+ const SCEV *rewriteUsingPredicate(const SCEV *Scev,
+                                   SCEVUnionPredicate &Preds);
+
private:
/// Compute the backedge taken count knowing the interval difference, the
/// stride and presence of the equality in the comparison.
private:
FoldingSet<SCEV> UniqueSCEVs;
+ FoldingSet<SCEVPredicate> UniquePreds;
BumpPtrAllocator SCEVAllocator;
/// The head of a linked list of all SCEVUnknown values that have been
/// block.
Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I);
+ /// \brief Generates a code sequence that evaluates this predicate.
+ /// The inserted instructions will be at position \p Loc.
+ /// The result will be of type i1. As implemented by expandEqualPredicate
+ /// (an "icmp ne" of both sides) and expandUnionPredicate (an OR over the
+ /// member checks, starting from 0), the value is 1 when the predicate does
+ /// NOT hold and 0 when it holds.
+ Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc);
+
+ /// \brief A specialized variant of expandCodeForPredicate, handling the
+ /// case when we are expanding code for a SCEVEqualPredicate.
+ /// The generated value compares both sides of the equality with "icmp ne",
+ /// so it is 1 when the two sides differ.
+ Value *expandEqualPredicate(const SCEVEqualPredicate *Pred,
+ Instruction *Loc);
+
+ /// \brief A specialized variant of expandCodeForPredicate, handling the
+ /// case when we are expanding code for a SCEVUnionPredicate.
+ /// The generated value is the OR of the checks expanded for each member of
+ /// the union, starting from the i1 constant 0.
+ Value *expandUnionPredicate(const SCEVUnionPredicate *Pred,
+ Instruction *Loc);
+
/// \brief Set the current IV increment loop and position.
void setIVIncInsertPos(const Loop *L, Instruction *Pos) {
assert(!CanonicalMode &&
const SCEV *llvm::replaceSymbolicStrideSCEV(ScalarEvolution *SE,
const ValueToValueMap &PtrToStride,
+ SCEVUnionPredicate &Preds,
Value *Ptr, Value *OrigPtr) {
-
const SCEV *OrigSCEV = SE->getSCEV(Ptr);
// If there is an entry in the map return the SCEV of the pointer with the
ValueToValueMap RewriteMap;
RewriteMap[StrideVal] = One;
- const SCEV *ByOne =
- SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true);
+ const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal));
+ const auto *CT =
+ static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType()));
+
+ Preds.add(SE->getEqualPredicate(U, CT));
+
+ const SCEV *ByOne = SE->rewriteUsingPredicate(OrigSCEV, Preds);
DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
<< "\n");
return ByOne;
}
// Otherwise, just return the SCEV of the original pointer.
- return SE->getSCEV(Ptr);
+ return OrigSCEV;
}
void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
unsigned DepSetId, unsigned ASId,
- const ValueToValueMap &Strides) {
+ const ValueToValueMap &Strides,
+ SCEVUnionPredicate &Preds) {
// Get the stride replaced scev.
- const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+ const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
assert(AR && "Invalid addrec expression");
const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet;
AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI,
- MemoryDepChecker::DepCandidates &DA)
- : DL(Dl), AST(*AA), LI(LI), DepCands(DA),
- IsRTCheckAnalysisNeeded(false) {}
+ MemoryDepChecker::DepCandidates &DA, SCEVUnionPredicate &Preds)
+ : DL(Dl), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false),
+ Preds(Preds) {}
/// \brief Register a load and whether it is only read from.
void addLoad(MemoryLocation &Loc, bool IsReadOnly) {
/// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared
/// while this remains set if we have potentially dependent accesses.
bool IsRTCheckAnalysisNeeded;
+
+ /// The SCEV predicate containing all the SCEV-related assumptions.
+ SCEVUnionPredicate &Preds;
};
} // end anonymous namespace
/// \brief Check whether a pointer can participate in a runtime bounds check.
static bool hasComputableBounds(ScalarEvolution *SE,
- const ValueToValueMap &Strides, Value *Ptr) {
- const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+ const ValueToValueMap &Strides, Value *Ptr,
+ Loop *L, SCEVUnionPredicate &Preds) {
+ const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
if (!AR)
return false;
else
++NumReadPtrChecks;
- if (hasComputableBounds(SE, StridesMap, Ptr) &&
+ if (hasComputableBounds(SE, StridesMap, Ptr, TheLoop, Preds) &&
// When we run after a failing dependency check we have to make sure
// we don't have wrapping pointers.
(!ShouldCheckStride ||
- isStridedPtr(SE, Ptr, TheLoop, StridesMap) == 1)) {
+ isStridedPtr(SE, Ptr, TheLoop, StridesMap, Preds) == 1)) {
// The id of the dependence set.
unsigned DepId;
// Each access has its own dependence set.
DepId = RunningDepId++;
- RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
+ RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, Preds);
DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
} else {
/// \brief Check whether the access through \p Ptr has a constant stride.
int llvm::isStridedPtr(ScalarEvolution *SE, Value *Ptr, const Loop *Lp,
- const ValueToValueMap &StridesMap) {
+ const ValueToValueMap &StridesMap,
+ SCEVUnionPredicate &Preds) {
Type *Ty = Ptr->getType();
assert(Ty->isPointerTy() && "Unexpected non-ptr");
return 0;
}
- const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
+ const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Preds, Ptr);
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
if (!AR) {
BPtr->getType()->getPointerAddressSpace())
return Dependence::Unknown;
- const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
- const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
+ const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, APtr);
+ const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, Preds, BPtr);
- int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides);
- int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides);
+ int StrideAPtr = isStridedPtr(SE, APtr, InnermostLoop, Strides, Preds);
+ int StrideBPtr = isStridedPtr(SE, BPtr, InnermostLoop, Strides, Preds);
const SCEV *Src = AScev;
const SCEV *Sink = BScev;
MemoryDepChecker::DepCandidates DependentAccesses;
AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(),
- AA, LI, DependentAccesses);
+ AA, LI, DependentAccesses, Preds);
// Holds the analyzed pointers. We don't want to call GetUnderlyingObjects
// multiple times on the same object. If the ptr is accessed twice, once
// read a few words, modify, and write a few words, and some of the
// words may be written to the same address.
bool IsReadOnlyPtr = false;
- if (Seen.insert(Ptr).second || !isStridedPtr(SE, Ptr, TheLoop, Strides)) {
+ if (Seen.insert(Ptr).second ||
+ !isStridedPtr(SE, Ptr, TheLoop, Strides, Preds)) {
++NumReads;
IsReadOnlyPtr = true;
}
const TargetLibraryInfo *TLI, AliasAnalysis *AA,
DominatorTree *DT, LoopInfo *LI,
const ValueToValueMap &Strides)
- : PtrRtChecking(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL),
+ : PtrRtChecking(SE), DepChecker(SE, L, Preds), TheLoop(L), SE(SE), DL(DL),
TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
MaxSafeDepDistBytes(-1U), CanVecMem(false),
StoreToLoopInvariantAddress(false) {
OS.indent(Depth) << "Store to invariant address was "
<< (StoreToLoopInvariantAddress ? "" : "not ")
<< "found in loop.\n";
+
+ OS.indent(Depth) << "SCEV assumptions:\n";
+ Preds.print(OS, Depth);
}
const LoopAccessInfo &
if (!LAI) {
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
- LAI = llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI,
- Strides);
+ LAI =
+ llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI, Strides);
#ifndef NDEBUG
LAI->NumSymbolicStrides = Strides.size();
#endif
UnsignedRanges(std::move(Arg.UnsignedRanges)),
SignedRanges(std::move(Arg.SignedRanges)),
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+ UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
AU.addRequiredTransitive<DominatorTreeWrapperPass>();
AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
+
+/// Returns the unique equality predicate for \p LHS == \p RHS, creating and
+/// registering it in UniquePreds if it does not exist yet.
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+                                   const SCEVConstant *RHS) {
+  // The identity of an equality predicate is its kind plus both operands.
+  FoldingSetNodeID Profile;
+  Profile.AddInteger(SCEVPredicate::P_Equal);
+  Profile.AddPointer(LHS);
+  Profile.AddPointer(RHS);
+
+  // Reuse an already uniqued node when there is one.
+  void *InsertPos = nullptr;
+  const auto *Existing = UniquePreds.FindNodeOrInsertPos(Profile, InsertPos);
+  if (Existing)
+    return Existing;
+
+  // Otherwise allocate a new node (and its interned ID) from the SCEV
+  // allocator and register it at the position found above.
+  auto *NewPred = new (SCEVAllocator)
+      SCEVEqualPredicate(Profile.Intern(SCEVAllocator), LHS, RHS);
+  UniquePreds.InsertNode(NewPred, InsertPos);
+  return NewPred;
+}
+
+namespace {
+/// Rewrites a SCEV expression under the assumptions recorded in a
+/// SCEVUnionPredicate: every SCEVUnknown that is the left hand side of an
+/// equality predicate in the union is replaced by that predicate's constant
+/// right hand side. All other nodes are handled by SCEVRewriteVisitor.
+///
+/// The class is only used inside this file, so it lives in an anonymous
+/// namespace to give it internal linkage.
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+  /// Convenience entry point: rewrites \p Scev using the predicates in \p A.
+  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A) {
+    SCEVPredicateRewriter Rewriter(SE, A);
+    return Rewriter.visit(Scev);
+  }
+
+  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+      : SCEVRewriteVisitor(SE), P(P) {}
+
+  /// Returns the constant an equality predicate assigns to \p Expr, or
+  /// \p Expr itself when no such predicate is recorded.
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    auto ExprPreds = P.getPredicatesForExpr(Expr);
+    for (auto *Pred : ExprPreds)
+      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+        if (IPred->getLHS() == Expr)
+          return IPred->getRHS();
+
+    return Expr;
+  }
+
+private:
+  /// The assumptions used for rewriting.
+  SCEVUnionPredicate &P;
+};
+} // end anonymous namespace
+
+/// Rewrites \p Scev under the assumptions in \p Preds by running a
+/// SCEVPredicateRewriter over it, and returns the resulting expression.
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+                                                   SCEVUnionPredicate &Preds) {
+  SCEVPredicateRewriter Rewriter(*this, Preds);
+  return Rewriter.visit(Scev);
+}
+
+/// SCEV predicates
+///
+/// Base-class constructor: records the interned folding-set ID \p ID and the
+/// kind tag \p Kind of the concrete predicate.
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+ SCEVPredicateKind Kind)
+ : FastID(ID), Kind(Kind) {}
+
+/// Creates the equality predicate \p LHS == \p RHS, tagged P_Equal and
+/// identified by the interned folding-set ID \p ID.
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVUnknown *LHS,
+ const SCEVConstant *RHS)
+ : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+/// An equality predicate implies \p N only when \p N is an equality predicate
+/// over the very same left and right hand side expressions.
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+  if (const auto *Other = dyn_cast<const SCEVEqualPredicate>(N))
+    return Other->LHS == LHS && Other->RHS == RHS;
+
+  // Any other kind of predicate is never implied.
+  return false;
+}
+
+/// Always returns false: an equality predicate is never reported as
+/// trivially true.
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+/// The expression an equality predicate applies to is its left hand side.
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+/// Prints the predicate on one line as "Equal predicate: LHS == RHS",
+/// indented by \p Depth.
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+/// The dummy ID is a null FoldingSetNodeIDRef of size 0; the new union is
+/// tagged P_Union and starts out without any member predicates.
+SCEVUnionPredicate::SCEVUnionPredicate()
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+/// A union is always true when each of its members is always true; an empty
+/// union is therefore always true.
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+  for (const SCEVPredicate *Member : Preds)
+    if (!Member->isAlwaysTrue())
+      return false;
+  return true;
+}
+
+/// Returns the predicates recorded for \p Expr, or an empty ArrayRef when the
+/// union holds no predicate for that expression.
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+  auto Found = SCEVToPreds.find(Expr);
+  if (Found != SCEVToPreds.end())
+    return Found->second;
+  return ArrayRef<const SCEVPredicate *>();
+}
+
+/// Returns true if this union implies \p N.
+///
+/// A union \p N is implied when every one of its members is implied. Any
+/// other predicate is implied when at least one member recorded for the same
+/// expression implies it.
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+  if (const auto *Other = dyn_cast<const SCEVUnionPredicate>(N)) {
+    for (const SCEVPredicate *Member : Other->Preds)
+      if (!implies(Member))
+        return false;
+    return true;
+  }
+
+  // Only predicates about the same expression can imply N.
+  auto Bucket = SCEVToPreds.find(N->getExpr());
+  if (Bucket == SCEVToPreds.end())
+    return false;
+
+  for (const SCEVPredicate *Candidate : Bucket->second)
+    if (Candidate->implies(N))
+      return true;
+  return false;
+}
+
+/// A union has no single associated expression, so this returns nullptr.
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+/// Prints every member predicate in insertion order, each indented by
+/// \p Depth. Nothing is printed for an empty union.
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  for (const SCEVPredicate *Member : Preds)
+    Member->print(OS, Depth);
+}
+
+/// Adds \p N to this union.
+///
+/// A union passed as \p N is flattened: its members are added one by one.
+/// A predicate that is already implied by this union is dropped. Everything
+/// else is appended to the member list and indexed by its expression.
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+  if (const auto *Nested = dyn_cast<const SCEVUnionPredicate>(N)) {
+    for (const SCEVPredicate *Member : Nested->Preds)
+      add(Member);
+    return;
+  }
+
+  // Only record predicates that carry new information.
+  if (!implies(N)) {
+    const SCEV *Key = N->getExpr();
+    assert(Key && "Only SCEVUnionPredicate doesn't have an "
+                  " associated expression!");
+
+    SCEVToPreds[Key].push_back(N);
+    Preds.push_back(N);
+  }
+}
return false;
}
+/// Expands the run-time check for \p Pred at \p IP by dispatching on the
+/// predicate kind to the matching specialised expander.
+Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
+                                            Instruction *IP) {
+  assert(IP);
+
+  // Deliberately no default case, so that every SCEVPredicateKind has to be
+  // listed here explicitly.
+  const SCEVPredicate::SCEVPredicateKind Kind = Pred->getKind();
+  switch (Kind) {
+  case SCEVPredicate::P_Equal:
+    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  case SCEVPredicate::P_Union:
+    return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
+  }
+  llvm_unreachable("Unknown SCEV predicate type");
+}
+
+/// Expands both sides of the equality \p Pred at \p IP and returns an
+/// "icmp ne" of the two values, i.e. a value that is 1 when the assumed
+/// equality does not hold.
+Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
+                                          Instruction *IP) {
+  const SCEVUnknown *LHS = Pred->getLHS();
+  const SCEVConstant *RHS = Pred->getRHS();
+
+  Value *LHSVal = expandCodeFor(LHS, LHS->getType(), IP);
+  Value *RHSVal = expandCodeFor(RHS, RHS->getType(), IP);
+
+  // Emit the comparison itself at the requested position.
+  Builder.SetInsertPoint(IP);
+  return Builder.CreateICmpNE(LHSVal, RHSVal, "ident.check");
+}
+
+/// Expands the checks of all members of \p Union at \p IP and ORs them
+/// together. The result starts as the i1 constant 0, which is also what is
+/// returned for an empty union.
+Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
+                                          Instruction *IP) {
+  Value *Combined =
+      ConstantInt::getNullValue(IntegerType::get(IP->getContext(), 1));
+
+  // Fold the check of every member predicate into the combined value.
+  for (const SCEVPredicate *Member : Union->getPredicates()) {
+    Value *MemberCheck = expandCodeForPredicate(Member, IP);
+    Builder.SetInsertPoint(IP);
+    Combined = Builder.CreateOr(Combined, MemberCheck);
+  }
+
+  return Combined;
+}
+
namespace {
// Search for a SCEV subexpression that is not safe to expand. Any expression
// that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely
cl::desc("The maximum allowed number of runtime memory checks with a "
"vectorize(enable) pragma."));
+/// Hidden command-line option "vectorize-scev-check-threshold" (default 16):
+/// the maximum number of SCEV checks allowed.
+/// NOTE(review): the place where this limit is enforced is not part of this
+/// hunk - confirm against the users of the option.
+static cl::opt<unsigned> VectorizeSCEVCheckThreshold(
+ "vectorize-scev-check-threshold", cl::init(16), cl::Hidden,
+ cl::desc("The maximum number of SCEV checks allowed."));
+
+/// Hidden command-line option "pragma-vectorize-scev-check-threshold"
+/// (default 128): the maximum number of SCEV checks allowed when a
+/// vectorize(enable) pragma is present.
+/// NOTE(review): the place where this limit is enforced is not part of this
+/// hunk - confirm against the users of the option.
+static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold(
+ "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden,
+ cl::desc("The maximum number of SCEV checks allowed with a "
+ "vectorize(enable) pragma"));
+
namespace {
// Forward declarations.
InnerLoopVectorizer(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
DominatorTree *DT, const TargetLibraryInfo *TLI,
const TargetTransformInfo *TTI, unsigned VecWidth,
- unsigned UnrollFactor)
+ unsigned UnrollFactor, SCEVUnionPredicate &Preds)
: OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()),
Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor),
TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr),
- AddedSafetyChecks(false) {}
+ AddedSafetyChecks(false), Preds(Preds) {}
// Perform the actual loop widening (vectorization).
// MinimumBitWidths maps scalar integer values to the smallest bitwidth they
typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>,
VectorParts> EdgeMaskCache;
- /// \brief Add checks for strides that were assumed to be 1.
- ///
- /// Returns the last check instruction and the first check instruction in the
- /// pair as (first, last).
- std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
-
/// Create an empty loop, based on the loop ranges of the old loop.
void createEmptyLoop();
/// Create a new induction variable inside L.
void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass);
/// Emit a bypass check to see if the vector trip count is nonzero.
void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass);
- /// Emit bypass checks to check if strides we've assumed to be one really are.
- void emitStrideChecks(Loop *L, BasicBlock *Bypass);
+ /// Emit a bypass check to see if all of the SCEV assumptions we've
+ /// had to make are correct.
+ void emitSCEVChecks(Loop *L, BasicBlock *Bypass);
/// Emit bypass checks to check any memory assumptions we may have made.
void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass);
-
+
/// This is a helper class that holds the vectorizer state. It maps scalar
/// instructions to vector instructions. When the code is 'unrolled' then
/// then a single scalar value is mapped to multiple vector parts. The parts
// Record whether runtime check is added.
bool AddedSafetyChecks;
+
+ /// The SCEV predicate containing all the SCEV-related assumptions.
+ /// The predicate is used to simplify existing expressions in the
+ /// context of existing SCEV assumptions. Since legality checking is
+ /// not done here, we don't need to use this predicate to record
+ /// further assumptions.
+ SCEVUnionPredicate &Preds;
};
class InnerLoopUnroller : public InnerLoopVectorizer {
public:
InnerLoopUnroller(Loop *OrigLoop, ScalarEvolution *SE, LoopInfo *LI,
DominatorTree *DT, const TargetLibraryInfo *TLI,
- const TargetTransformInfo *TTI, unsigned UnrollFactor)
- : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor) {}
+ const TargetTransformInfo *TTI, unsigned UnrollFactor,
+ SCEVUnionPredicate &Preds)
+ : InnerLoopVectorizer(OrigLoop, SE, LI, DT, TLI, TTI, 1, UnrollFactor,
+ Preds) {}
private:
void scalarizeInstruction(Instruction *Instr,
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
- InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT)
- : SE(SE), TheLoop(L), DT(DT) {}
+ InterleavedAccessInfo(ScalarEvolution *SE, Loop *L, DominatorTree *DT,
+ SCEVUnionPredicate &Preds)
+ : SE(SE), TheLoop(L), DT(DT), Preds(Preds) {}
~InterleavedAccessInfo() {
SmallSet<InterleaveGroup *, 4> DelSet;
Loop *TheLoop;
DominatorTree *DT;
+ /// The SCEV predicate containing all the SCEV-related assumptions.
+ /// The predicate is used to simplify SCEV expressions in the
+ /// context of existing SCEV assumptions. The interleaved access
+ /// analysis can also add new predicates (for example by versioning
+ /// strides of pointers).
+ SCEVUnionPredicate &Preds;
+
/// Holds the relationships between the members and the interleave group.
DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap;
Function *F, const TargetTransformInfo *TTI,
LoopAccessAnalysis *LAA,
LoopVectorizationRequirements *R,
- const LoopVectorizeHints *H)
+ const LoopVectorizeHints *H,
+ SCEVUnionPredicate &Preds)
: NumPredStores(0), TheLoop(L), SE(SE), TLI(TLI), TheFunction(F),
- TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(SE, L, DT),
- Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false),
- Requirements(R), Hints(H) {}
+ TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr),
+ InterleaveInfo(SE, L, DT, Preds), Induction(nullptr),
+ WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), Hints(H),
+ Preds(Preds) {}
/// ReductionList contains the reduction descriptors for all
/// of the reductions that were found in the loop.
/// While vectorizing these instructions we have to generate a
/// call to the appropriate masked intrinsic
- SmallPtrSet<const Instruction*, 8> MaskedOp;
+ SmallPtrSet<const Instruction *, 8> MaskedOp;
+
+ /// The SCEV predicate containing all the SCEV-related assumptions.
+ /// The predicate is used to simplify SCEV expressions in the
+ /// context of existing SCEV assumptions. The analysis will also
+ /// add a minimal set of new predicates if this is required to
+ /// enable vectorization/unrolling.
+ SCEVUnionPredicate &Preds;
};
/// LoopVectorizationCostModel - estimates the expected speedups due to
LoopVectorizationLegality *Legal,
const TargetTransformInfo &TTI,
const TargetLibraryInfo *TLI, DemandedBits *DB,
- AssumptionCache *AC,
- const Function *F, const LoopVectorizeHints *Hints,
- SmallPtrSetImpl<const Value *> &ValuesToIgnore)
+ AssumptionCache *AC, const Function *F,
+ const LoopVectorizeHints *Hints,
+ SmallPtrSetImpl<const Value *> &ValuesToIgnore,
+ SCEVUnionPredicate &Preds)
: TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
}
}
+ SCEVUnionPredicate Preds;
+
// Check if it is legal to vectorize the loop.
LoopVectorizationRequirements Requirements;
LoopVectorizationLegality LVL(L, SE, DT, TLI, AA, F, TTI, LAA,
- &Requirements, &Hints);
+ &Requirements, &Hints, Preds);
if (!LVL.canVectorize()) {
DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");
emitMissedWarning(F, L, Hints);
// Use the cost model.
LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
- ValuesToIgnore);
+ ValuesToIgnore, Preds);
// Check the function attributes to find out if this function should be
// optimized for size.
assert(IC > 1 && "interleave count should not be 1 or 0");
// If we decided that it is not legal to vectorize the loop then
// interleave it.
- InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
+ InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC, Preds);
Unroller.vectorize(&LVL, CM.MinBWs);
emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
Twine(IC) + ")");
} else {
// If we decided that it is *legal* to vectorize the loop then do it.
- InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
+ InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC, Preds);
LB.vectorize(&LVL, CM.MinBWs);
++LoopsVectorized;
// %idxprom = zext i32 %mul to i64 << Safe cast.
// %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
//
- Last = replaceSymbolicStrideSCEV(SE, Strides,
+ Last = replaceSymbolicStrideSCEV(SE, Strides, Preds,
Gep->getOperand(InductionOperand), Gep);
if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
Last =
}
}
-static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
- Instruction *Loc) {
- if (FirstInst)
- return FirstInst;
- if (Instruction *I = dyn_cast<Instruction>(V))
- return I->getParent() == Loc->getParent() ? I : nullptr;
- return nullptr;
-}
-
-std::pair<Instruction *, Instruction *>
-InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
- Instruction *tnullptr = nullptr;
- if (!Legal->mustCheckStrides())
- return std::pair<Instruction *, Instruction *>(tnullptr, tnullptr);
-
- IRBuilder<> ChkBuilder(Loc);
-
- // Emit checks.
- Value *Check = nullptr;
- Instruction *FirstInst = nullptr;
- for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
- SE = Legal->strides_end();
- SI != SE; ++SI) {
- Value *Ptr = stripIntegerCast(*SI);
- Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
- "stride.chk");
- // Store the first instruction we create.
- FirstInst = getFirstInst(FirstInst, C, Loc);
- if (Check)
- Check = ChkBuilder.CreateOr(Check, C);
- else
- Check = C;
- }
-
- // We have to do this trickery because the IRBuilder might fold the check to a
- // constant expression in which case there is no Instruction anchored in a
- // the block.
- LLVMContext &Ctx = Loc->getContext();
- Instruction *TheCheck =
- BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
- ChkBuilder.Insert(TheCheck, "stride.not.one");
- FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
-
- return std::make_pair(FirstInst, TheCheck);
-}
-
-PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L,
- Value *Start,
- Value *End,
- Value *Step,
+PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start,
+ Value *End, Value *Step,
Instruction *DL) {
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
LoopBypassBlocks.push_back(BB);
}
-void InnerLoopVectorizer::emitStrideChecks(Loop *L,
- BasicBlock *Bypass) {
+void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
BasicBlock *BB = L->getLoopPreheader();
-
- // Generate the code to check that the strides we assumed to be one are really
- // one. We want the new basic block to start at the first instruction in a
+
+ // Generate the code to check the SCEV assumptions that we made.
+ // We want the new basic block to start at the first instruction in a
// sequence of instructions that form a check.
- Instruction *StrideCheck;
- Instruction *FirstCheckInst;
- std::tie(FirstCheckInst, StrideCheck) = addStrideCheck(BB->getTerminator());
- if (!StrideCheck)
- return;
+ SCEVExpander Exp(*SE, Bypass->getModule()->getDataLayout(), "scev.check");
+ Value *SCEVCheck = Exp.expandCodeForPredicate(&Preds, BB->getTerminator());
+
+ if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
+ if (C->isZero())
+ return;
- // Create a new block containing the stride check.
+ // Create a new block containing the SCEV check.
- BB->setName("vector.stridecheck");
+ BB->setName("vector.scevcheck");
auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph");
if (L->getParentLoop())
L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
ReplaceInstWithInst(BB->getTerminator(),
- BranchInst::Create(Bypass, NewBB, StrideCheck));
+ BranchInst::Create(Bypass, NewBB, SCEVCheck));
LoopBypassBlocks.push_back(BB);
AddedSafetyChecks = true;
}
// Now, compare the new count to zero. If it is zero skip the vector loop and
// jump to the scalar loop.
emitVectorLoopEnteredCheck(Lp, ScalarPH);
- // Generate the code to check that the strides we assumed to be one are really
- // one. We want the new basic block to start at the first instruction in a
- // sequence of instructions that form a check.
- emitStrideChecks(Lp, ScalarPH);
+ // Generate the code to check any assumptions that we've made for SCEV
+ // expressions.
+ emitSCEVChecks(Lp, ScalarPH);
+
// Generate the code that checks in runtime if arrays overlap. We put the
// checks into a separate block to make the more common case of few elements
// faster.
// Analyze interleaved memory accesses.
if (UseInterleaved)
- InterleaveInfo.analyzeInterleaving(Strides);
+ InterleaveInfo.analyzeInterleaving(Strides);
+
+ unsigned SCEVThreshold = VectorizeSCEVCheckThreshold;
+ if (Hints->getForce() == LoopVectorizeHints::FK_Enabled)
+ SCEVThreshold = PragmaVectorizeSCEVCheckThreshold;
+
+ if (Preds.getComplexity() > SCEVThreshold) {
+ emitAnalysis(VectorizationReport()
+ << "Too many SCEV assumptions need to be made and checked "
+ << "at runtime");
+ DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n");
+ return false;
+ }
// Okay! We can vectorize. At this point we don't have any other mem analysis
// which may limit our maximum vectorization factor, so just return true with
}
Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
+ Preds.add(&LAI->Preds);
return true;
}
StoreInst *SI = dyn_cast<StoreInst>(I);
Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand();
- int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides);
+ int Stride = isStridedPtr(SE, Ptr, TheLoop, Strides, Preds);
// The factor of the corresponding interleave group.
unsigned Factor = std::abs(Stride);
if (Factor < 2 || Factor > MaxInterleaveGroupFactor)
continue;
- const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
+ const SCEV *Scev = replaceSymbolicStrideSCEV(SE, Strides, Preds, Ptr);
PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType());