return nullptr;
}
+/// \brief IR Values for the lower and upper bounds of a pointer evolution.
+struct PointerBounds {
+ Value *Start;
+ Value *End;
+};
+
+/// \brief Expand code for the lower and upper bound of the pointer group \p CG
+/// in \p TheLoop. \return the values for the bounds.
+static PointerBounds
+expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop,
+ Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE,
+ const RuntimePointerChecking &PtrRtChecking) {
+ Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue;
+ const SCEV *Sc = SE->getSCEV(Ptr);
+
+ if (SE->isLoopInvariant(Sc, TheLoop)) {
+ DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr
+ << "\n");
+ return {Ptr, Ptr};
+ } else {
+ unsigned AS = Ptr->getType()->getPointerAddressSpace();
+ LLVMContext &Ctx = Loc->getContext();
+
+ // Use this type for pointer arithmetic.
+ Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+ Value *Start = nullptr, *End = nullptr;
+
+ DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+ Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+ DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n");
+ return {Start, End};
+ }
+}
+
+/// \brief Turns a collection of checks into a collection of expanded upper and
+/// lower bounds for both pointers in the check.
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(
+ const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks,
+ Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp,
+ const RuntimePointerChecking &PtrRtChecking) {
+ SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
+
+ // Here we're relying on the SCEV Expander's cache to only emit code for the
+ // same bounds once.
+ std::transform(
+ PointerChecks.begin(), PointerChecks.end(),
+ std::back_inserter(ChecksWithBounds),
+ [&](const RuntimePointerChecking::PointerCheck &Check) {
+ return std::make_pair(
+ expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking),
+ expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking));
+ });
+
+ return ChecksWithBounds;
+}
+
std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck(
- Instruction *Loc, const SmallVectorImpl<int> *PtrPartition) const {
- if (!PtrRtChecking.Need)
- return std::make_pair(nullptr, nullptr);
+ Instruction *Loc,
+ const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
+ const {
- SmallVector<TrackingVH<Value>, 2> Starts;
- SmallVector<TrackingVH<Value>, 2> Ends;
+ SCEVExpander Exp(*SE, DL, "induction");
+ auto ExpandedChecks =
+ expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, PtrRtChecking);
LLVMContext &Ctx = Loc->getContext();
- SCEVExpander Exp(*SE, DL, "induction");
Instruction *FirstInst = nullptr;
-
- for (unsigned i = 0; i < PtrRtChecking.CheckingGroups.size(); ++i) {
- const RuntimePointerChecking::CheckingPtrGroup &CG =
- PtrRtChecking.CheckingGroups[i];
- Value *Ptr = PtrRtChecking.Pointers[CG.Members[0]].PointerValue;
- const SCEV *Sc = SE->getSCEV(Ptr);
-
- if (SE->isLoopInvariant(Sc, TheLoop)) {
- DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr
- << "\n");
- Starts.push_back(Ptr);
- Ends.push_back(Ptr);
- } else {
- unsigned AS = Ptr->getType()->getPointerAddressSpace();
-
- // Use this type for pointer arithmetic.
- Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
- Value *Start = nullptr, *End = nullptr;
-
- DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
- Start = Exp.expandCodeFor(CG.Low, PtrArithTy, Loc);
- End = Exp.expandCodeFor(CG.High, PtrArithTy, Loc);
- DEBUG(dbgs() << "Start: " << *CG.Low << " End: " << *CG.High << "\n");
- Starts.push_back(Start);
- Ends.push_back(End);
- }
- }
-
IRBuilder<> ChkBuilder(Loc);
// Our instructions might fold to a constant.
Value *MemoryRuntimeCheck = nullptr;
- for (unsigned i = 0; i < PtrRtChecking.CheckingGroups.size(); ++i) {
- for (unsigned j = i + 1; j < PtrRtChecking.CheckingGroups.size(); ++j) {
- const RuntimePointerChecking::CheckingPtrGroup &CGI =
- PtrRtChecking.CheckingGroups[i];
- const RuntimePointerChecking::CheckingPtrGroup &CGJ =
- PtrRtChecking.CheckingGroups[j];
- if (!PtrRtChecking.needsChecking(CGI, CGJ, PtrPartition))
- continue;
-
- unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace();
- unsigned AS1 = Starts[j]->getType()->getPointerAddressSpace();
-
- assert((AS0 == Ends[j]->getType()->getPointerAddressSpace()) &&
- (AS1 == Ends[i]->getType()->getPointerAddressSpace()) &&
- "Trying to bounds check pointers with different address spaces");
-
- Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
- Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
-
- Value *Start0 = ChkBuilder.CreateBitCast(Starts[i], PtrArithTy0, "bc");
- Value *Start1 = ChkBuilder.CreateBitCast(Starts[j], PtrArithTy1, "bc");
- Value *End0 = ChkBuilder.CreateBitCast(Ends[i], PtrArithTy1, "bc");
- Value *End1 = ChkBuilder.CreateBitCast(Ends[j], PtrArithTy0, "bc");
-
- Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
- FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
- Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
- FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
- Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ for (const auto &Check : ExpandedChecks) {
+ const PointerBounds &A = Check.first, &B = Check.second;
+ unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
+ unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
+
+ assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
+ (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+ "Trying to bounds check pointers with different address spaces");
+
+ Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+ Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+ Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
+ Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
+ Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
+ Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
+
+ Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
+ FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
+ Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
+ FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
+ Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+ if (MemoryRuntimeCheck) {
+ IsConflict =
+ ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
- if (MemoryRuntimeCheck) {
- IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict,
- "conflict.rdx");
- FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
- }
- MemoryRuntimeCheck = IsConflict;
}
+ MemoryRuntimeCheck = IsConflict;
}
if (!MemoryRuntimeCheck)
return std::make_pair(FirstInst, Check);
}
+std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeCheck(
+ Instruction *Loc, const SmallVectorImpl<int> *PtrPartition) const {
+ if (!PtrRtChecking.Need)
+ return std::make_pair(nullptr, nullptr);
+
+ SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+ for (unsigned i = 0; i < PtrRtChecking.CheckingGroups.size(); ++i) {
+ for (unsigned j = i + 1; j < PtrRtChecking.CheckingGroups.size(); ++j) {
+ const RuntimePointerChecking::CheckingPtrGroup &CGI =
+ PtrRtChecking.CheckingGroups[i];
+ const RuntimePointerChecking::CheckingPtrGroup &CGJ =
+ PtrRtChecking.CheckingGroups[j];
+
+ if (PtrRtChecking.needsChecking(CGI, CGJ, PtrPartition))
+ Checks.push_back(std::make_pair(&CGI, &CGJ));
+ }
+ }
+
+ return addRuntimeCheck(Loc, Checks);
+}
+
LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
const DataLayout &DL,
const TargetLibraryInfo *TLI, AliasAnalysis *AA,