+namespace {
+  /// StrideCompare - Ordering predicate for strides. Constant strides come
+  /// first and are ordered among themselves by ascending absolute value; if
+  /// two constants have the same absolute value, the positive one comes
+  /// first. Two non-constant strides never compare less than each other, and
+  /// a non-constant stride never compares less than a constant one. e.g.
+  ///   4, -1, X, 1, 2 ==> 1, -1, 2, 4, X
+  struct StrideCompare {
+    /// Returns true if LHS must be ordered before RHS.
+    bool operator()(const SCEVHandle &LHS, const SCEVHandle &RHS) const {
+      SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS);
+      SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
+      if (LHSC && RHSC) {
+        int64_t LV = LHSC->getValue()->getSExtValue();
+        int64_t RV = RHSC->getValue()->getSExtValue();
+        // Take the magnitude in the unsigned domain: negating INT64_MIN as a
+        // signed value overflows, which is undefined behavior.
+        uint64_t ALV = (LV < 0) ? (0 - static_cast<uint64_t>(LV))
+                                : static_cast<uint64_t>(LV);
+        uint64_t ARV = (RV < 0) ? (0 - static_cast<uint64_t>(RV))
+                                : static_cast<uint64_t>(RV);
+        // Same magnitude: the positive value goes first.
+        if (ALV == ARV)
+          return LV > RV;
+        return ALV < ARV;
+      }
+      // Only "constant vs. non-constant" orders LHS first; every other
+      // remaining combination is reported as not-less-than.
+      return (LHSC && !RHSC);
+    }
+  };
+}
+
+/// ChangeCompareStride - If a loop termination compare instruction is the
+/// only use of its stride, and the comparison is against a constant value,
+/// try to eliminate the stride by moving the compare instruction to another
+/// stride and changing its constant operand accordingly. e.g.
+///
+///   loop:
+///   ...
+///   v1 = v1 + 3
+///   v2 = v2 + 1
+///   if (v2 < 10) goto loop
+/// =>
+///   loop:
+///   ...
+///   v1 = v1 + 3
+///   if (v1 < 30) goto loop
+///
+/// Returns the compare instruction to use from now on: Cond itself if no
+/// rewrite happened, otherwise the newly created compare. On a successful
+/// rewrite CondUse and CondStride are updated to refer to the new use and
+/// the new stride. Note that StrideOrder is re-sorted as a side effect.
+ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
+                                                IVStrideUse* &CondUse,
+                                                const SCEVHandle* &CondStride) {
+  // Need at least one other stride to move to, and the compare must be the
+  // sole user of its own stride.
+  if (StrideOrder.size() < 2 ||
+      IVUsesByStride[*CondStride].Users.size() != 1)
+    return Cond;
+  const SCEVConstant *SC = dyn_cast<SCEVConstant>(*CondStride);
+  if (!SC) return Cond;
+  ConstantInt *C = dyn_cast<ConstantInt>(Cond->getOperand(1));
+  if (!C) return Cond;
+
+  ICmpInst::Predicate Predicate = Cond->getPredicate();
+  int64_t CmpSSInt = SC->getValue()->getSExtValue();
+  // Magnitude of the compare stride, computed in the unsigned domain so that
+  // neither an int-only abs() overload nor INT64_MIN can corrupt it.
+  uint64_t AbsCmpSS = (CmpSSInt < 0) ? (0 - static_cast<uint64_t>(CmpSSInt))
+                                     : static_cast<uint64_t>(CmpSSInt);
+  int64_t CmpVal = C->getValue().getSExtValue();
+  unsigned BitWidth = C->getValue().getBitWidth();
+  uint64_t SignBit = 1ULL << (BitWidth-1);
+  const Type *CmpTy = C->getType();
+  const Type *NewCmpTy = NULL;
+  unsigned TyBits = CmpTy->getPrimitiveSizeInBits();
+  unsigned NewTyBits = 0;
+  int64_t NewCmpVal = CmpVal;
+  SCEVHandle *NewStride = NULL;
+  Value *NewIncV = NULL;
+  int64_t Scale = 1;
+
+  // Look for a suitable stride / iv as replacement. Throughout the loop,
+  // NewCmpVal == CmpVal means "no acceptable candidate found (yet)".
+  std::stable_sort(StrideOrder.begin(), StrideOrder.end(), StrideCompare());
+  for (unsigned i = 0, e = StrideOrder.size(); i != e; ++i) {
+    std::map<SCEVHandle, IVUsersOfOneStride>::iterator SI =
+      IVUsesByStride.find(StrideOrder[i]);
+    if (!isa<SCEVConstant>(SI->first))
+      continue;
+    int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue();
+    uint64_t AbsSS = (SSInt < 0) ? (0 - static_cast<uint64_t>(SSInt))
+                                 : static_cast<uint64_t>(SSInt);
+    // The candidate must be a strictly larger exact multiple of the compare
+    // stride. A zero compare stride is skipped so it never reaches the
+    // modulus / division.
+    if (CmpSSInt == 0 || AbsSS <= AbsCmpSS || (SSInt % CmpSSInt) != 0)
+      continue;
+
+    Scale = SSInt / CmpSSInt;
+    NewCmpVal = CmpVal * Scale;
+    APInt Mul = APInt(BitWidth, NewCmpVal);
+    // Check for overflow: the scaled constant must survive a round trip
+    // through the original bit width.
+    if (Mul.getSExtValue() != NewCmpVal) {
+      NewCmpVal = CmpVal;
+      continue;
+    }
+
+    // Watch out for overflow: a signed compare must not have the sign of
+    // its constant changed by the scaling.
+    if (ICmpInst::isSignedPredicate(Predicate) &&
+        (CmpVal & SignBit) != (NewCmpVal & SignBit))
+      NewCmpVal = CmpVal;
+
+    if (NewCmpVal != CmpVal) {
+      // Pick the best iv to use trying to avoid a cast. If no user has the
+      // compare's type, the last user examined is kept.
+      NewIncV = NULL;
+      for (std::vector<IVStrideUse>::iterator UI = SI->second.Users.begin(),
+             E = SI->second.Users.end(); UI != E; ++UI) {
+        NewIncV = UI->OperandValToReplace;
+        if (NewIncV->getType() == CmpTy)
+          break;
+      }
+      if (!NewIncV) {
+        NewCmpVal = CmpVal;
+        continue;
+      }
+
+      NewCmpTy = NewIncV->getType();
+      NewTyBits = isa<PointerType>(NewCmpTy)
+        ? UIntPtrTy->getPrimitiveSizeInBits()
+        : NewCmpTy->getPrimitiveSizeInBits();
+      if (RequiresTypeConversion(NewCmpTy, CmpTy)) {
+        // Check if it is possible to rewrite it using
+        // an iv / stride of a smaller integer type.
+        bool TruncOk = false;
+        if (NewCmpTy->isInteger()) {
+          unsigned Bits = NewTyBits;
+          if (ICmpInst::isSignedPredicate(Predicate))
+            --Bits;
+          // Shifting a 64-bit value by 64 or more bits is undefined, so the
+          // mask is only built when it is representable; anything wider is
+          // conservatively rejected.
+          if (Bits < 64) {
+            uint64_t Mask = (1ULL << Bits) - 1;
+            if (((uint64_t)NewCmpVal & Mask) == (uint64_t)NewCmpVal)
+              TruncOk = true;
+          }
+        }
+        if (!TruncOk) {
+          NewCmpVal = CmpVal;
+          continue;
+        }
+      }
+
+      // Don't rewrite if use offset is non-constant and the new type is
+      // of a different type.
+      // FIXME: too conservative?
+      if (NewTyBits != TyBits && !isa<SCEVConstant>(CondUse->Offset)) {
+        NewCmpVal = CmpVal;
+        continue;
+      }
+
+      bool AllUsesAreAddresses = true;
+      std::vector<BasedUser> UsersToProcess;
+      SCEVHandle CommonExprs = CollectIVUsers(SI->first, SI->second, L,
+                                              AllUsesAreAddresses,
+                                              UsersToProcess);
+      // Avoid rewriting the compare instruction with an iv of new stride
+      // if it's likely the new stride uses will be rewritten using the
+      // stride of the compare instruction.
+      if (AllUsesAreAddresses &&
+          ValidStride(!CommonExprs->isZero(), Scale, UsersToProcess)) {
+        NewCmpVal = CmpVal;
+        continue;
+      }
+
+      // If scale is negative, multiplying both sides of the comparison by it
+      // reverses the direction (x < c becomes Scale*x > Scale*c), so use the
+      // swapped predicate unless it's testing for equality. The inverse
+      // predicate would turn '<' into '>=' and give the wrong answer when
+      // both sides are equal.
+      if (Scale < 0 && !Cond->isEquality())
+        Predicate = ICmpInst::getSwappedPredicate(Predicate);
+
+      NewStride = &StrideOrder[i];
+      break;
+    }
+  }
+
+  // Forgo this transformation if the increment happens to be
+  // unfortunately positioned after the condition, and the condition
+  // has multiple uses which prevent it from being moved immediately
+  // before the branch. See
+  // test/Transforms/LoopStrengthReduce/change-compare-stride-trickiness-*.ll
+  // for an example of this situation.
+  if (!Cond->hasOneUse())
+    for (BasicBlock::iterator I = Cond, E = Cond->getParent()->end();
+         I != E; ++I)
+      if (I == NewIncV)
+        return Cond;
+
+  if (NewCmpVal != CmpVal) {
+    // Create a new compare instruction using new stride / iv.
+    ICmpInst *OldCond = Cond;
+    Value *RHS;
+    if (!isa<PointerType>(NewCmpTy))
+      RHS = ConstantInt::get(NewCmpTy, NewCmpVal);
+    else {
+      // Pointer-typed iv: materialize the constant as an integer and cast it.
+      RHS = ConstantInt::get(UIntPtrTy, NewCmpVal);
+      RHS = SCEVExpander::InsertCastOfTo(Instruction::IntToPtr, RHS, NewCmpTy);
+    }
+    // Insert new compare instruction.
+    Cond = new ICmpInst(Predicate, NewIncV, RHS,
+                        L->getHeader()->getName() + ".termcond",
+                        OldCond);
+
+    // Remove the old compare instruction. The old indvar is probably dead too.
+    DeadInsts.insert(cast<Instruction>(CondUse->OperandValToReplace));
+    SE->deleteValueFromRecords(OldCond);
+    OldCond->replaceAllUsesWith(Cond);
+    OldCond->eraseFromParent();
+
+    // Compute the scaled offset while CondUse is still alive: the pop_back
+    // below removes the old stride's only user, so CondUse must not be
+    // dereferenced after it.
+    SCEVHandle NewOffset;
+    if (TyBits == NewTyBits) {
+      NewOffset = SE->getMulExpr(CondUse->Offset,
+                                 SE->getConstant(ConstantInt::get(CmpTy,
+                                                                  Scale)));
+    } else {
+      // Different width: the offset was verified to be a constant above, so
+      // rebuild it directly in the new type.
+      int64_t OldOffset =
+        cast<SCEVConstant>(CondUse->Offset)->getValue()->getSExtValue();
+      NewOffset = SE->getConstant(ConstantInt::get(NewCmpTy,
+                                                   OldOffset * Scale));
+    }
+
+    // Move the use record from the old stride to the new one.
+    IVUsesByStride[*CondStride].Users.pop_back();
+    IVUsesByStride[*NewStride].addUser(NewOffset, Cond, NewIncV);
+    CondUse = &IVUsesByStride[*NewStride].Users.back();
+    CondStride = NewStride;
+    ++NumEliminated;
+  }
+
+  return Cond;
+}
+