/// \brief Returns the estimated complexity of this predicate.
/// This is roughly measured in the number of run-time checks required.
- virtual unsigned getComplexity() { return 1; }
+ virtual unsigned getComplexity() const { return 1; }
/// \brief Returns true if the predicate is always true. This means that no
/// assumptions were made and nothing needs to be checked at run-time.
/// \brief We estimate the complexity of a union predicate as the size
/// number of predicates in the union.
- unsigned getComplexity() override { return Preds.size(); }
+ unsigned getComplexity() const override { return Preds.size(); }
/// Methods for support type inquiry through isa, cast, and dyn_cast:
static inline bool classof(const SCEVPredicate *P) {
#define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H
#include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
class Loop;
class LoopAccessInfo;
class LoopInfo;
+class ScalarEvolution;
/// \brief This class emits a version of the loop where run-time checks ensure
/// that may-alias pointers can't overlap.
/// already has a preheader.
class LoopVersioning {
public:
- /// \brief Expects MemCheck, LoopAccessInfo, Loop, LoopInfo, DominatorTree
- /// as input. It uses runtime check provided by user.
- LoopVersioning(SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
- const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
- DominatorTree *DT);
-
/// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input.
- /// It uses default runtime check provided by LoopAccessInfo.
- LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI,
- DominatorTree *DT);
+ /// It uses runtime check provided by the user. If \p UseLAIChecks is true,
+ /// we will retain the default checks made by LAI. Otherwise, construct an
+ /// object having no checks and we expect the user to add them.
+ LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
+ DominatorTree *DT, ScalarEvolution *SE,
+ bool UseLAIChecks = true);
/// \brief Performs the CFG manipulation part of versioning the loop including
/// the DominatorTree and LoopInfo updates.
/// loop may alias (i.e. one of the memchecks failed).
Loop *getNonVersionedLoop() { return NonVersionedLoop; }
+ /// \brief Sets the runtime alias checks for versioning the loop.
+ void setAliasChecks(
+ const SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks);
+
+ /// \brief Sets the runtime SCEV checks for versioning the loop.
+ void setSCEVChecks(SCEVUnionPredicate Check);
+
private:
/// \brief Adds the necessary PHI nodes for the versioned loops based on the
/// loop-defined values used outside of the loop.
/// in NonVersionedLoop.
ValueToValueMapTy VMap;
- /// \brief The set of checks that we are versioning for.
- SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+ /// \brief The set of alias checks that we are versioning for.
+ SmallVector<RuntimePointerChecking::PointerCheck, 4> AliasChecks;
+
+ /// \brief The set of SCEV checks that we are versioning for.
+ SCEVUnionPredicate Preds;
/// \brief Analyses used.
const LoopAccessInfo &LAI;
LoopInfo *LI;
DominatorTree *DT;
+ ScalarEvolution *SE;
};
}
"if-convertible by the loop vectorizer"),
cl::init(false));
+static cl::opt<unsigned> DistributeSCEVCheckThreshold(
+ "loop-distribute-scev-check-threshold", cl::init(8), cl::Hidden,
+ cl::desc("The maximum number of SCEV checks allowed for Loop "
+ "Distribution"));
+
STATISTIC(NumLoopsDistributed, "Number of loops distributed");
namespace {
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
LAA = &getAnalysis<LoopAccessAnalysis>();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
// Build up a worklist of inner-loops to vectorize. This is necessary as the
// act of distributing a loop creates new loops and can invalidate iterators
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<ScalarEvolutionWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addPreserved<LoopInfoWrapperPass>();
AU.addRequired<LoopAccessAnalysis>();
return false;
}
+ // Don't distribute the loop if we need too many SCEV run-time checks.
+ const SCEVUnionPredicate &Pred = LAI.Preds;
+ if (Pred.getComplexity() > DistributeSCEVCheckThreshold) {
+ DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
+ return false;
+ }
+
DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n");
// We're done forming the partitions set up the reverse mapping from
// instructions to partitions.
if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator())
SplitBlock(PH, PH->getTerminator(), DT, LI);
- // If we need run-time checks to disambiguate pointers are run-time, version
- // the loop now.
+ // If we need run-time checks, version the loop now.
auto PtrToPartition = Partitions.computePartitionSetForPointers(LAI);
const auto *RtPtrChecking = LAI.getRuntimePointerChecking();
const auto &AllChecks = RtPtrChecking->getChecks();
auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
RtPtrChecking);
- if (!Checks.empty()) {
+
+ if (!Pred.isAlwaysTrue() || !Checks.empty()) {
DEBUG(dbgs() << "\nPointers:\n");
DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
- LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT);
+ LoopVersioning LVer(LAI, L, LI, DT, SE, false);
+ LVer.setAliasChecks(std::move(Checks));
+ LVer.setSCEVChecks(LAI.Preds);
LVer.versionLoop(DefsUsedOutside);
}
LoopInfo *LI;
LoopAccessAnalysis *LAA;
DominatorTree *DT;
+ ScalarEvolution *SE;
};
} // anonymous namespace
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false)
namespace llvm {
cl::desc("Max number of memchecks allowed per eliminated load on average"),
cl::init(1));
+static cl::opt<unsigned> LoadElimSCEVCheckThreshold(
+ "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden,
+ cl::desc("The maximum number of SCEV checks allowed for Loop "
+ "Load Elimination"));
+
+
STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");
namespace {
return false;
}
+ if (LAI.Preds.getComplexity() > LoadElimSCEVCheckThreshold) {
+ DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
+ return false;
+ }
+
// Point of no-return, start the transformation. First, version the loop if
// necessary.
- if (!Checks.empty()) {
- LoopVersioning LV(std::move(Checks), LAI, L, LI, DT);
+ if (!Checks.empty() || !LAI.Preds.isAlwaysTrue()) {
+ LoopVersioning LV(LAI, L, LI, DT, SE, false);
+ LV.setAliasChecks(std::move(Checks));
+ LV.setSCEVChecks(LAI.Preds);
LV.versionLoop();
}
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
using namespace llvm;
-LoopVersioning::LoopVersioning(
- SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks,
- const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT)
- : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)),
- LAI(LAI), LI(LI), DT(DT) {
+LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
+ DominatorTree *DT, ScalarEvolution *SE,
+ bool UseLAIChecks)
+ : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
+ SE(SE) {
assert(L->getExitBlock() && "No single exit block");
assert(L->getLoopPreheader() && "No preheader");
+ if (UseLAIChecks) {
+ setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
+ setSCEVChecks(LAI.Preds);
+ }
}
-LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L,
- LoopInfo *LI, DominatorTree *DT)
- : VersionedLoop(L), NonVersionedLoop(nullptr),
- Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo),
- LI(LI), DT(DT) {
- assert(L->getExitBlock() && "No single exit block");
- assert(L->getLoopPreheader() && "No preheader");
+void LoopVersioning::setAliasChecks(
+ const SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
+ AliasChecks = std::move(Checks);
+}
+
+void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
+ Preds = std::move(Check);
}
void LoopVersioning::versionLoop(
const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
Instruction *FirstCheckInst;
Instruction *MemRuntimeCheck;
+ Value *SCEVRuntimeCheck;
+ Value *RuntimeCheck = nullptr;
+
// Add the memcheck in the original preheader (this is empty initially).
- BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader();
+ BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
std::tie(FirstCheckInst, MemRuntimeCheck) =
- LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks);
+ LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
+ const SCEVUnionPredicate &Pred = LAI.Preds;
+ SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
+ "scev.check");
+ SCEVRuntimeCheck =
+ Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
+ auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
+
+ // Discard the SCEV runtime check if it is always true.
+ if (CI && CI->isZero())
+ SCEVRuntimeCheck = nullptr;
+
+ if (MemRuntimeCheck && SCEVRuntimeCheck) {
+ RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
+ SCEVRuntimeCheck, "ldist.safe");
+ if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
+ I->insertBefore(RuntimeCheckBB->getTerminator());
+ } else
+ RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
+
+ assert(RuntimeCheck && "called even though we don't need "
+ "any runtime checks");
+
// Rename the block to make the IR more readable.
- MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck");
+ RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
+ ".lver.check");
// Create empty preheader for the loop (and after cloning for the
// non-versioned loop).
- BasicBlock *PH = SplitBlock(MemCheckBB, MemCheckBB->getTerminator(), DT, LI);
+ BasicBlock *PH =
+ SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
// Clone the loop including the preheader.
// block is a join between the two loops.
SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
NonVersionedLoop =
- cloneLoopWithPreheader(PH, MemCheckBB, VersionedLoop, VMap, ".lver.orig",
- LI, DT, NonVersionedLoopBlocks);
+ cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
+ ".lver.orig", LI, DT, NonVersionedLoopBlocks);
remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
// Insert the conditional branch based on the result of the memchecks.
- Instruction *OrigTerm = MemCheckBB->getTerminator();
+ Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
- VersionedLoop->getLoopPreheader(), MemRuntimeCheck,
- OrigTerm);
+ VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
OrigTerm->eraseFromParent();
// The loops merge in the original exit block. This is now dominated by the
// memchecking block.
- DT->changeImmediateDominator(VersionedLoop->getExitBlock(), MemCheckBB);
+ DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
// Adds the necessary PHI nodes for the versioned loops based on the
// loop-defined values used outside of the loop.
; Since the checks to A and A + 4 get merged, this will give us a
; total of 8 compares.
;
-; CHECK: for.body.lver.memcheck:
+; CHECK: for.body.lver.check:
; CHECK: = icmp
; CHECK: = icmp
define void @f(i32* %A, i32* %B, i32* %C, i64 %N) {
-; CHECK: for.body.lver.memcheck:
+; CHECK: for.body.lver.check:
; CHECK: %found.conflict{{.*}} =
; CHECK-NOT: %found.conflict{{.*}} =
entry:
br label %for.body
-; AGGRESSIVE: for.body.lver.memcheck:
+; AGGRESSIVE: for.body.lver.check:
; AGGRESSIVE: %found.conflict{{.*}} =
; AGGRESSIVE: %found.conflict{{.*}} =
; AGGRESSIVE-NOT: %found.conflict{{.*}} =