Add support for memory runtime check. When we can, we calculate array bounds.
authorNadav Rotem <nrotem@apple.com>
Fri, 9 Nov 2012 07:09:44 +0000 (07:09 +0000)
committerNadav Rotem <nrotem@apple.com>
Fri, 9 Nov 2012 07:09:44 +0000 (07:09 +0000)
If the arrays are found to be disjoint then we run the vectorized version of
the loop. If they are not, we run the scalar code.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@167608 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Vectorize/LoopVectorize.cpp
test/Transforms/LoopVectorize/runtime-check.ll [new file with mode: 0644]

index 892808760f76574b3a6e5159049a651c2568e127..b657993e848bb95aaff5988fdae75bb5255f0d35 100644 (file)
@@ -78,6 +78,10 @@ VectorizationFactor("force-vector-width", cl::init(0), cl::Hidden,
 /// We don't vectorize loops with a known constant trip count below this number.
 const unsigned TinyTripCountThreshold = 16;
 
+/// When performing a runtime memory check, do not check more than this
+/// numner of pointers. Notice that the check is quadratic!
+const unsigned RuntimeMemoryCheckThreshold = 2;
+
 namespace {
 
 // Forward declarations.
@@ -242,6 +246,15 @@ public:
     ReductionKind Kind;
   };
 
+  // This POD struct holds information about the memory runtime legality
+  // check that a group of pointers do not overlap.
+  struct RuntimePointerCheck {
+    /// This flag indicates if we need to add the runtime check.
+    bool Need;
+    /// Holds the pointers that we need to check.
+    SmallVector<Value*, 2> Pointers;
+  };
+
   /// ReductionList contains the reduction descriptors for all
   /// of the reductions that were found in the loop.
   typedef DenseMap<PHINode*, ReductionDescriptor> ReductionList;
@@ -263,9 +276,14 @@ public:
   /// This check allows us to vectorize A[idx] into a wide load/store.
   bool isConsecutiveGep(Value *Ptr);
 
+  /// Returns true if the value V is uniform within the loop.
+  bool isUniform(Value *V);
+
   /// Returns true if this instruction will remain scalar after vectorization.
   bool isUniformAfterVectorization(Instruction* I) {return Uniforms.count(I);}
 
+  /// Returns the information that we collected about runtime memory check.
+  RuntimePointerCheck *getRuntimePointerCheck() {return &PtrRtCheck; }
 private:
   /// Check if a single basic block loop is vectorizable.
   /// At this point we know that this is a loop with a constant trip count
@@ -286,6 +304,8 @@ private:
   bool isReductionInstr(Instruction *I, ReductionKind Kind);
   /// Returns True, if 'Phi' is an induction variable.
   bool isInductionVariable(PHINode *Phi);
+  /// Return true if we
+  bool hasComputableBounds(Value *Ptr);
 
   /// The loop that we evaluate.
   Loop *TheLoop;
@@ -306,6 +326,9 @@ private:
   /// This set holds the variables which are known to be uniform after
   /// vectorization.
   SmallPtrSet<Instruction*, 4> Uniforms;
+  /// We need to check that all of the pointers in this list are disjoint
+  /// at runtime.
+  RuntimePointerCheck PtrRtCheck;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
@@ -506,6 +529,10 @@ bool LoopVectorizationLegality::isConsecutiveGep(Value *Ptr) {
   return false;
 }
 
+bool LoopVectorizationLegality::isUniform(Value *V) {
+  return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
+}
+
 Value *SingleBlockLoopVectorizer::getVectorValue(Value *V) {
   assert(!V->getType()->isVectorTy() && "Can't widen a vector");
   // If we saved a vectorized copy of V, use it.
@@ -631,13 +658,29 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
    ...
    */
 
+  OldInduction = Legal->getInduction();
+  assert(OldInduction && "We must have a single phi node.");
+  Type *IdxTy = OldInduction->getType();
+
+  // Find the loop boundaries.
+  const SCEV *ExitCount = SE->getExitCount(OrigLoop, OrigLoop->getHeader());
+  assert(ExitCount != SE->getCouldNotCompute() && "Invalid loop count");
+
+  // Get the total trip count from the count by adding 1.
+  ExitCount = SE->getAddExpr(ExitCount,
+                             SE->getConstant(ExitCount->getType(), 1));
+  // We may need to extend the index in case there is a type mismatch.
+  // We know that the count starts at zero and does not overflow.
+  // We are using Zext because it should be less expensive.
+  if (ExitCount->getType() != IdxTy)
+    ExitCount = SE->getZeroExtendExpr(ExitCount, IdxTy);
+
   // This is the original scalar-loop preheader.
   BasicBlock *BypassBlock = OrigLoop->getLoopPreheader();
   BasicBlock *ExitBlock = OrigLoop->getExitBlock();
   assert(ExitBlock && "Must have an exit block");
 
   // The loop index does not have to start at Zero. It starts with this value.
-  OldInduction = Legal->getInduction();
   Value *StartIdx = OldInduction->getIncomingValueForBlock(BypassBlock);
 
   assert(OrigLoop->getNumBlocks() == 1 && "Invalid loop");
@@ -655,8 +698,6 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
                                  "scalar.preheader");
   // Find the induction variable.
   BasicBlock *OldBasicBlock = OrigLoop->getHeader();
-  assert(OldInduction && "We must have a single phi node.");
-  Type *IdxTy = OldInduction->getType();
 
   // Use this IR builder to create the loop instructions (Phi, Br, Cmp)
   // inside the loop.
@@ -666,25 +707,11 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
   Induction = Builder.CreatePHI(IdxTy, 2, "index");
   Constant *Step = ConstantInt::get(IdxTy, VF);
 
-  // Find the loop boundaries.
-  const SCEV *ExitCount = SE->getExitCount(OrigLoop, OrigLoop->getHeader());
-  assert(ExitCount != SE->getCouldNotCompute() && "Invalid loop count");
-
-  // Get the total trip count from the count by adding 1.
-  ExitCount = SE->getAddExpr(ExitCount,
-                             SE->getConstant(ExitCount->getType(), 1));
-
   // Expand the trip count and place the new instructions in the preheader.
   // Notice that the pre-header does not change, only the loop body.
   SCEVExpander Exp(*SE, "induction");
   Instruction *Loc = BypassBlock->getTerminator();
 
-  // We may need to extend the index in case there is a type mismatch.
-  // We know that the count starts at zero and does not overflow.
-  // We are using Zext because it should be less expensive.
-  if (ExitCount->getType() != Induction->getType())
-    ExitCount = SE->getZeroExtendExpr(ExitCount, IdxTy);
-
   // Count holds the overall loop count (N).
   Value *Count = Exp.expandCodeFor(ExitCount, Induction->getType(), Loc);
 
@@ -704,15 +731,85 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
                                IdxEndRoundDown,
                                StartIdx,
                                "cmp.zero", Loc);
+
+  LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck =
+    Legal->getRuntimePointerCheck();
+  Value *MemoryRuntimeCheck = 0;
+  if (PtrRtCheck->Need) {
+    unsigned NumPointers = PtrRtCheck->Pointers.size();
+    SmallVector<Value* , 2> Starts;
+    SmallVector<Value* , 2> Ends;
+
+    // Use this type for pointer arithmetic.
+    Type* PtrArithTy = PtrRtCheck->Pointers[0]->getType();
+
+    for (unsigned i=0; i < NumPointers; ++i) {
+      Value *Ptr = PtrRtCheck->Pointers[i];
+      const SCEV *Sc = SE->getSCEV(Ptr);
+
+      if (SE->isLoopInvariant(Sc, OrigLoop)) {
+        DEBUG(dbgs() << "LV1: Adding RT check for a loop invariant ptr:" <<
+              *Ptr <<"\n");
+        Starts.push_back(Ptr);
+        Ends.push_back(Ptr);
+      } else {
+        DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr <<"\n");
+        const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
+        Value *Start = Exp.expandCodeFor(AR->getStart(), PtrArithTy, Loc);
+        const SCEV *Ex = SE->getExitCount(OrigLoop, OrigLoop->getHeader());
+        const SCEV *ScEnd = AR->evaluateAtIteration(Ex, *SE);
+        assert(!isa<SCEVCouldNotCompute>(ScEnd) && "Invalid scev range.");
+        Value *End = Exp.expandCodeFor(ScEnd, PtrArithTy, Loc);
+        Starts.push_back(Start);
+        Ends.push_back(End);
+      }
+    }
+
+    for (unsigned i=0; i < NumPointers; ++i) {
+      for (unsigned j=i+1; j < NumPointers; ++j) {
+        Value *Cmp0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE,
+                                      Starts[0], Ends[1], "bound0", Loc);
+        Value *Cmp1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE,
+                                      Starts[1], Ends[0], "bound1", Loc);
+        Value *IsConflict = BinaryOperator::Create(Instruction::And, Cmp0, Cmp1,
+                                                    "found.conflict", Loc);
+        if (MemoryRuntimeCheck) {
+          MemoryRuntimeCheck = BinaryOperator::Create(Instruction::Or,
+                                                      MemoryRuntimeCheck,
+                                                      IsConflict,
+                                                      "conflict.rdx", Loc);
+        } else {
+          MemoryRuntimeCheck = IsConflict;
+        }
+      }
+    }
+  }// end of need-runtime-check code.
+
+  // If we are using memory runtime checks, include them in.
+  if (MemoryRuntimeCheck) {
+    Cmp = BinaryOperator::Create(Instruction::Or, Cmp, MemoryRuntimeCheck,
+                                 "CntOrMem", Loc);
+  }
+
   BranchInst::Create(MiddleBlock, VectorPH, Cmp, Loc);
   // Remove the old terminator.
   Loc->eraseFromParent();
 
+  // We are going to resume the execution of the scalar loop.
+  // This PHI decides on what number to start. If we come from the
+  // vector loop then we need to start with the end index minus the
+  // index modulo VF. If we come from a bypass edge then we need to start
+  // from the real start.
+  PHINode* ResumeIndex = PHINode::Create(IdxTy, 2, "resume.idx",
+                                         MiddleBlock->getTerminator());
+  ResumeIndex->addIncoming(StartIdx, BypassBlock);
+  ResumeIndex->addIncoming(IdxEndRoundDown, VecBody);
+
   // Add a check in the middle block to see if we have completed
   // all of the iterations in the first vector loop.
   // If (N - N%VF) == N, then we *don't* need to run the remainder.
   Value *CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, IdxEnd,
-                                IdxEndRoundDown, "cmp.n",
+                                ResumeIndex, "cmp.n",
                                 MiddleBlock->getTerminator());
 
   BranchInst::Create(ExitBlock, ScalarPH, CmpN, MiddleBlock->getTerminator());
@@ -732,7 +829,7 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
 
   // Fix the scalar body iteration count.
   unsigned BlockIdx = OldInduction->getBasicBlockIndex(ScalarPH);
-  OldInduction->setIncomingValue(BlockIdx, IdxEndRoundDown);
+  OldInduction->setIncomingValue(BlockIdx, ResumeIndex);
 
   // Get ready to start creating new instructions into the vectorized body.
   Builder.SetInsertPoint(VecBody->getFirstInsertionPt());
@@ -905,7 +1002,12 @@ SingleBlockLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
         Type *StTy = VectorType::get(SI->getValueOperand()->getType(), VF);
         Value *Ptr = SI->getPointerOperand();
         unsigned Alignment = SI->getAlignment();
+
+        assert(!Legal->isUniform(Ptr) &&
+               "We do not allow storing to uniform addresses");
+
         GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
+
         // This store does not use GEPs.
         if (!Legal->isConsecutiveGep(Gep)) {
           scalarizeInstruction(Inst);
@@ -935,8 +1037,9 @@ SingleBlockLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
         unsigned Alignment = LI->getAlignment();
         GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
 
-        // We don't have a gep. Scalarize the load.
-        if (!Legal->isConsecutiveGep(Gep)) {
+        // If we don't have a gep, or that the pointer is loop invariant,
+        // scalarize the load.
+        if (!Gep || Legal->isUniform(Gep) || !Legal->isConsecutiveGep(Gep)) {
           scalarizeInstruction(Inst);
           break;
         }
@@ -1146,12 +1249,6 @@ bool LoopVectorizationLegality::canVectorize() {
   BasicBlock *BB = TheLoop->getHeader();
   DEBUG(dbgs() << "LV: Found a loop: " << BB->getName() << "\n");
 
-  // Go over each instruction and look at memory deps.
-  if (!canVectorizeBlock(*BB)) {
-    DEBUG(dbgs() << "LV: Can't vectorize this loop header\n");
-    return false;
-  }
-
   // ScalarEvolution needs to be able to find the exit count.
   const SCEV *ExitCount = SE->getExitCount(TheLoop, BB);
   if (ExitCount == SE->getCouldNotCompute()) {
@@ -1167,7 +1264,15 @@ bool LoopVectorizationLegality::canVectorize() {
     return false;
   }
 
-  DEBUG(dbgs() << "LV: We can vectorize this loop!\n");
+  // Go over each instruction and look at memory deps.
+  if (!canVectorizeBlock(*BB)) {
+    DEBUG(dbgs() << "LV: Can't vectorize this loop header\n");
+    return false;
+  }
+
+  DEBUG(dbgs() << "LV: We can vectorize this loop" <<
+        (PtrRtCheck.Need ? " (with a runtime bound check)" : "")
+        <<"!\n");
 
   // Okay! We can vectorize. At this point we don't have any other mem analysis
   // which may limit our maximum vectorization factor, so just return true with
@@ -1304,6 +1409,8 @@ bool LoopVectorizationLegality::canVectorizeMemory(BasicBlock &BB) {
   // Holds the Load and Store *instructions*.
   ValueVector Loads;
   ValueVector Stores;
+  PtrRtCheck.Pointers.clear();
+  PtrRtCheck.Need = false;
 
   // Scan the BB and collect legal loads and stores.
   for (BasicBlock::iterator it = BB.begin(), e = BB.end(); it != e; ++it) {
@@ -1361,6 +1468,12 @@ bool LoopVectorizationLegality::canVectorizeMemory(BasicBlock &BB) {
     StoreInst *ST = dyn_cast<StoreInst>(*I);
     assert(ST && "Bad StoreInst");
     Value* Ptr = ST->getPointerOperand();
+
+    if (isUniform(Ptr)) {
+      DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n");
+      return false;
+    }
+
     // If we did *not* see this pointer before, insert it to
     // the read-write list. At this phase it is only a 'write' list.
     if (Seen.insert(Ptr))
@@ -1390,6 +1503,39 @@ bool LoopVectorizationLegality::canVectorizeMemory(BasicBlock &BB) {
     return true;
   }
 
+  // Find pointers with computable bounds. We are going to use this information
+  // to place a runtime bound check.
+  bool RT = true;
+  for (I = ReadWrites.begin(), IE = ReadWrites.end(); I != IE; ++I)
+    if (hasComputableBounds(*I)) {
+      PtrRtCheck.Pointers.push_back(*I);
+      DEBUG(dbgs() << "LV: Found a runtime check ptr:" << **I <<"\n");
+    } else {
+      RT = false;
+      break;
+    }
+  for (I = Reads.begin(), IE = Reads.end(); I != IE; ++I)
+    if (hasComputableBounds(*I)) {
+      PtrRtCheck.Pointers.push_back(*I);
+      DEBUG(dbgs() << "LV: Found a runtime check ptr:" << **I <<"\n");
+    } else {
+      RT = false;
+      break;
+    }
+
+  // Check that we did not collect too many pointers or found a
+  // unsizeable pointer.
+  if (!RT || PtrRtCheck.Pointers.size() > RuntimeMemoryCheckThreshold) {
+    PtrRtCheck.Pointers.clear();
+    RT = false;
+  }
+
+  PtrRtCheck.Need = RT;
+
+  if (RT) {
+    DEBUG(dbgs() << "LV: We can perform a memory runtime check if needed.\n");
+  }
+
   // Now that the pointers are in two lists (Reads and ReadWrites), we
   // can check that there are no conflicts between each of the writes and
   // between the writes to the reads.
@@ -1404,12 +1550,12 @@ bool LoopVectorizationLegality::canVectorizeMemory(BasicBlock &BB) {
          it != e; ++it) {
       if (!isIdentifiedObject(*it)) {
         DEBUG(dbgs() << "LV: Found an unidentified write ptr:"<< **it <<"\n");
-        return false;
+        return RT;
       }
       if (!WriteObjects.insert(*it)) {
         DEBUG(dbgs() << "LV: Found a possible write-write reorder:"
               << **it <<"\n");
-        return false;
+        return RT;
       }
     }
     TempObjects.clear();
@@ -1422,18 +1568,21 @@ bool LoopVectorizationLegality::canVectorizeMemory(BasicBlock &BB) {
          it != e; ++it) {
       if (!isIdentifiedObject(*it)) {
         DEBUG(dbgs() << "LV: Found an unidentified read ptr:"<< **it <<"\n");
-        return false;
+        return RT;
       }
       if (WriteObjects.count(*it)) {
         DEBUG(dbgs() << "LV: Found a possible read/write reorder:"
               << **it <<"\n");
-        return false;
+        return RT;
       }
     }
     TempObjects.clear();
   }
 
-  // All is okay.
+  // It is safe to vectorize and we don't need any runtime checks.
+  DEBUG(dbgs() << "LV: We don't need a runtime memory check.\n");
+  PtrRtCheck.Pointers.clear();
+  PtrRtCheck.Need = false;
   return true;
 }
 
@@ -1556,6 +1705,15 @@ bool LoopVectorizationLegality::isInductionVariable(PHINode *Phi) {
   return true;
 }
 
+bool LoopVectorizationLegality::hasComputableBounds(Value *Ptr) {
+  const SCEV *PhiScev = SE->getSCEV(Ptr);
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
+  if (!AR)
+    return false;
+
+  return AR->isAffine();
+}
+
 unsigned
 LoopVectorizationCostModel::findBestVectorizationFactor(unsigned VF) {
   if (!VTTI) {
diff --git a/test/Transforms/LoopVectorize/runtime-check.ll b/test/Transforms/LoopVectorize/runtime-check.ll
new file mode 100644 (file)
index 0000000..23933cf
--- /dev/null
@@ -0,0 +1,36 @@
+; RUN: opt < %s  -loop-vectorize -force-vector-width=4 -dce -instcombine -licm -S | FileCheck %s
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+target triple = "x86_64-apple-macosx10.9.0"
+
+; Make sure we vectorize this loop:
+; int foo(float *a, float *b, int n) {
+;   for (int i=0; i<n; ++i)
+;     a[i] = b[i] * 3;
+; }
+
+;CHECK: load <4 x float>
+define i32 @foo(float* nocapture %a, float* nocapture %b, i32 %n) nounwind uwtable ssp {
+entry:
+  %cmp6 = icmp sgt i32 %n, 0
+  br i1 %cmp6, label %for.body, label %for.end
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds float* %b, i64 %indvars.iv
+  %0 = load float* %arrayidx, align 4, !tbaa !0
+  %mul = fmul float %0, 3.000000e+00
+  %arrayidx2 = getelementptr inbounds float* %a, i64 %indvars.iv
+  store float %mul, float* %arrayidx2, align 4, !tbaa !0
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %n
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body, %entry
+  ret i32 undef
+}
+
+!0 = metadata !{metadata !"float", metadata !1}
+!1 = metadata !{metadata !"omnipotent char", metadata !2}
+!2 = metadata !{metadata !"Simple C/C++ TBAA"}