Add a new SCEV representing signed division.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the scalar evolution analysis
11 // engine, which is used primarily to analyze expressions involving induction
12 // variables in loops.
13 //
14 // There are several aspects to this library.  First is the representation of
15 // scalar expressions, which are represented as subclasses of the SCEV class.
16 // These classes are used to represent certain types of subexpressions that we
17 // can handle.  These classes are reference counted, managed by the SCEVHandle
18 // class.  We only create one SCEV of a particular shape, so pointer-comparisons
19 // for equality are legal.
20 //
21 // One important aspect of the SCEV objects is that they are never cyclic, even
22 // if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
23 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
24 // recurrence) then we represent it directly as a recurrence node, otherwise we
25 // represent it as a SCEVUnknown node.
26 //
27 // In addition to being able to represent expressions of various types, we also
28 // have folders that are used to build the *canonical* representation for a
29 // particular expression.  These folders are capable of using a variety of
30 // rewrite rules to simplify the expressions.
31 //
32 // Once the folders are defined, we can implement the more interesting
33 // higher-level code, such as the code that recognizes PHI nodes of various
34 // types, computes the execution count of a loop, etc.
35 //
36 // TODO: We should use these routines and value representations to implement
37 // dependence analysis!
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 // There are several good references for the techniques used in this analysis.
42 //
43 //  Chains of recurrences -- a method to expedite the evaluation
44 //  of closed-form functions
45 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
46 //
47 //  On computational properties of chains of recurrences
48 //  Eugene V. Zima
49 //
50 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
51 //  Robert A. van Engelen
52 //
53 //  Efficient Symbolic Analysis for Optimizing Compilers
54 //  Robert A. van Engelen
55 //
56 //  Using the chains of recurrences algebra for data dependence testing and
57 //  induction variable substitution
58 //  MS Thesis, Johnie Birch
59 //
60 //===----------------------------------------------------------------------===//
61
62 #define DEBUG_TYPE "scalar-evolution"
63 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
64 #include "llvm/Constants.h"
65 #include "llvm/DerivedTypes.h"
66 #include "llvm/GlobalVariable.h"
67 #include "llvm/Instructions.h"
68 #include "llvm/Analysis/ConstantFolding.h"
69 #include "llvm/Analysis/LoopInfo.h"
70 #include "llvm/Assembly/Writer.h"
71 #include "llvm/Transforms/Scalar.h"
72 #include "llvm/Support/CFG.h"
73 #include "llvm/Support/CommandLine.h"
74 #include "llvm/Support/Compiler.h"
75 #include "llvm/Support/ConstantRange.h"
76 #include "llvm/Support/InstIterator.h"
77 #include "llvm/Support/ManagedStatic.h"
78 #include "llvm/Support/MathExtras.h"
79 #include "llvm/Support/Streams.h"
80 #include "llvm/ADT/Statistic.h"
81 #include <ostream>
82 #include <algorithm>
83 #include <cmath>
84 using namespace llvm;
85
86 STATISTIC(NumArrayLenItCounts,
87           "Number of trip counts computed with array length");
88 STATISTIC(NumTripCountsComputed,
89           "Number of loops with predictable loop counts");
90 STATISTIC(NumTripCountsNotComputed,
91           "Number of loops without predictable loop counts");
92 STATISTIC(NumBruteForceTripCountsComputed,
93           "Number of loops with trip counts computed by force");
94
95 static cl::opt<unsigned>
96 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
97                         cl::desc("Maximum number of iterations SCEV will "
98                                  "symbolically execute a constant derived loop"),
99                         cl::init(100));
100
101 static RegisterPass<ScalarEvolution>
102 R("scalar-evolution", "Scalar Evolution Analysis", false, true);
103 char ScalarEvolution::ID = 0;
104
105 //===----------------------------------------------------------------------===//
106 //                           SCEV class definitions
107 //===----------------------------------------------------------------------===//
108
109 //===----------------------------------------------------------------------===//
110 // Implementation of the SCEV class.
111 //
112 SCEV::~SCEV() {}
113 void SCEV::dump() const {
114   print(cerr);
115 }
116
117 uint32_t SCEV::getBitWidth() const {
118   if (const IntegerType* ITy = dyn_cast<IntegerType>(getType()))
119     return ITy->getBitWidth();
120   return 0;
121 }
122
123 bool SCEV::isZero() const {
124   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
125     return SC->getValue()->isZero();
126   return false;
127 }
128
129
130 SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {}
131
132 bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
133   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
134   return false;
135 }
136
137 const Type *SCEVCouldNotCompute::getType() const {
138   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
139   return 0;
140 }
141
142 bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
143   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
144   return false;
145 }
146
147 SCEVHandle SCEVCouldNotCompute::
148 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
149                                   const SCEVHandle &Conc,
150                                   ScalarEvolution &SE) const {
151   return this;
152 }
153
154 void SCEVCouldNotCompute::print(std::ostream &OS) const {
155   OS << "***COULDNOTCOMPUTE***";
156 }
157
158 bool SCEVCouldNotCompute::classof(const SCEV *S) {
159   return S->getSCEVType() == scCouldNotCompute;
160 }
161
162
163 // SCEVConstants - Only allow the creation of one SCEVConstant for any
164 // particular value.  Don't use a SCEVHandle here, or else the object will
165 // never be deleted!
166 static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants;
167
168
169 SCEVConstant::~SCEVConstant() {
170   SCEVConstants->erase(V);
171 }
172
173 SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) {
174   SCEVConstant *&R = (*SCEVConstants)[V];
175   if (R == 0) R = new SCEVConstant(V);
176   return R;
177 }
178
179 SCEVHandle ScalarEvolution::getConstant(const APInt& Val) {
180   return getConstant(ConstantInt::get(Val));
181 }
182
183 const Type *SCEVConstant::getType() const { return V->getType(); }
184
185 void SCEVConstant::print(std::ostream &OS) const {
186   WriteAsOperand(OS, V, false);
187 }
188
189 // SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any
190 // particular input.  Don't use a SCEVHandle here, or else the object will
191 // never be deleted!
192 static ManagedStatic<std::map<std::pair<SCEV*, const Type*>, 
193                      SCEVTruncateExpr*> > SCEVTruncates;
194
195 SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty)
196   : SCEV(scTruncate), Op(op), Ty(ty) {
197   assert(Op->getType()->isInteger() && Ty->isInteger() &&
198          "Cannot truncate non-integer value!");
199   assert(Op->getType()->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits()
200          && "This is not a truncating conversion!");
201 }
202
203 SCEVTruncateExpr::~SCEVTruncateExpr() {
204   SCEVTruncates->erase(std::make_pair(Op, Ty));
205 }
206
207 void SCEVTruncateExpr::print(std::ostream &OS) const {
208   OS << "(truncate " << *Op << " to " << *Ty << ")";
209 }
210
211 // SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any
212 // particular input.  Don't use a SCEVHandle here, or else the object will never
213 // be deleted!
214 static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
215                      SCEVZeroExtendExpr*> > SCEVZeroExtends;
216
217 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty)
218   : SCEV(scZeroExtend), Op(op), Ty(ty) {
219   assert(Op->getType()->isInteger() && Ty->isInteger() &&
220          "Cannot zero extend non-integer value!");
221   assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
222          && "This is not an extending conversion!");
223 }
224
225 SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
226   SCEVZeroExtends->erase(std::make_pair(Op, Ty));
227 }
228
229 void SCEVZeroExtendExpr::print(std::ostream &OS) const {
230   OS << "(zeroextend " << *Op << " to " << *Ty << ")";
231 }
232
233 // SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
234 // particular input.  Don't use a SCEVHandle here, or else the object will never
235 // be deleted!
236 static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
237                      SCEVSignExtendExpr*> > SCEVSignExtends;
238
239 SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
240   : SCEV(scSignExtend), Op(op), Ty(ty) {
241   assert(Op->getType()->isInteger() && Ty->isInteger() &&
242          "Cannot sign extend non-integer value!");
243   assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
244          && "This is not an extending conversion!");
245 }
246
247 SCEVSignExtendExpr::~SCEVSignExtendExpr() {
248   SCEVSignExtends->erase(std::make_pair(Op, Ty));
249 }
250
251 void SCEVSignExtendExpr::print(std::ostream &OS) const {
252   OS << "(signextend " << *Op << " to " << *Ty << ")";
253 }
254
255 // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
256 // particular input.  Don't use a SCEVHandle here, or else the object will never
257 // be deleted!
258 static ManagedStatic<std::map<std::pair<unsigned, std::vector<SCEV*> >,
259                      SCEVCommutativeExpr*> > SCEVCommExprs;
260
261 SCEVCommutativeExpr::~SCEVCommutativeExpr() {
262   SCEVCommExprs->erase(std::make_pair(getSCEVType(),
263                                       std::vector<SCEV*>(Operands.begin(),
264                                                          Operands.end())));
265 }
266
267 void SCEVCommutativeExpr::print(std::ostream &OS) const {
268   assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
269   const char *OpStr = getOperationStr();
270   OS << "(" << *Operands[0];
271   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
272     OS << OpStr << *Operands[i];
273   OS << ")";
274 }
275
276 SCEVHandle SCEVCommutativeExpr::
277 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
278                                   const SCEVHandle &Conc,
279                                   ScalarEvolution &SE) const {
280   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
281     SCEVHandle H =
282       getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
283     if (H != getOperand(i)) {
284       std::vector<SCEVHandle> NewOps;
285       NewOps.reserve(getNumOperands());
286       for (unsigned j = 0; j != i; ++j)
287         NewOps.push_back(getOperand(j));
288       NewOps.push_back(H);
289       for (++i; i != e; ++i)
290         NewOps.push_back(getOperand(i)->
291                          replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
292
293       if (isa<SCEVAddExpr>(this))
294         return SE.getAddExpr(NewOps);
295       else if (isa<SCEVMulExpr>(this))
296         return SE.getMulExpr(NewOps);
297       else if (isa<SCEVSMaxExpr>(this))
298         return SE.getSMaxExpr(NewOps);
299       else if (isa<SCEVUMaxExpr>(this))
300         return SE.getUMaxExpr(NewOps);
301       else
302         assert(0 && "Unknown commutative expr!");
303     }
304   }
305   return this;
306 }
307
308
309 // SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
310 // input.  Don't use a SCEVHandle here, or else the object will never be
311 // deleted!
312 static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>, 
313                      SCEVUDivExpr*> > SCEVUDivs;
314
315 SCEVUDivExpr::~SCEVUDivExpr() {
316   SCEVUDivs->erase(std::make_pair(LHS, RHS));
317 }
318
319 void SCEVUDivExpr::print(std::ostream &OS) const {
320   OS << "(" << *LHS << " /u " << *RHS << ")";
321 }
322
323 const Type *SCEVUDivExpr::getType() const {
324   return LHS->getType();
325 }
326
327
328 // SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular
329 // input.  Don't use a SCEVHandle here, or else the object will never be
330 // deleted!
331 static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>, 
332                      SCEVSDivExpr*> > SCEVSDivs;
333
334 SCEVSDivExpr::~SCEVSDivExpr() {
335   SCEVSDivs->erase(std::make_pair(LHS, RHS));
336 }
337
338 void SCEVSDivExpr::print(std::ostream &OS) const {
339   OS << "(" << *LHS << " /s " << *RHS << ")";
340 }
341
342 const Type *SCEVSDivExpr::getType() const {
343   return LHS->getType();
344 }
345
346
347 // SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any
348 // particular input.  Don't use a SCEVHandle here, or else the object will never
349 // be deleted!
350 static ManagedStatic<std::map<std::pair<const Loop *, std::vector<SCEV*> >,
351                      SCEVAddRecExpr*> > SCEVAddRecExprs;
352
353 SCEVAddRecExpr::~SCEVAddRecExpr() {
354   SCEVAddRecExprs->erase(std::make_pair(L,
355                                         std::vector<SCEV*>(Operands.begin(),
356                                                            Operands.end())));
357 }
358
359 SCEVHandle SCEVAddRecExpr::
360 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
361                                   const SCEVHandle &Conc,
362                                   ScalarEvolution &SE) const {
363   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
364     SCEVHandle H =
365       getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
366     if (H != getOperand(i)) {
367       std::vector<SCEVHandle> NewOps;
368       NewOps.reserve(getNumOperands());
369       for (unsigned j = 0; j != i; ++j)
370         NewOps.push_back(getOperand(j));
371       NewOps.push_back(H);
372       for (++i; i != e; ++i)
373         NewOps.push_back(getOperand(i)->
374                          replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
375
376       return SE.getAddRecExpr(NewOps, L);
377     }
378   }
379   return this;
380 }
381
382
383 bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
384   // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't
385   // contain L and if the start is invariant.
386   return !QueryLoop->contains(L->getHeader()) &&
387          getOperand(0)->isLoopInvariant(QueryLoop);
388 }
389
390
391 void SCEVAddRecExpr::print(std::ostream &OS) const {
392   OS << "{" << *Operands[0];
393   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
394     OS << ",+," << *Operands[i];
395   OS << "}<" << L->getHeader()->getName() + ">";
396 }
397
398 // SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular
399 // value.  Don't use a SCEVHandle here, or else the object will never be
400 // deleted!
401 static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns;
402
403 SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); }
404
405 bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
406   // All non-instruction values are loop invariant.  All instructions are loop
407   // invariant if they are not contained in the specified loop.
408   if (Instruction *I = dyn_cast<Instruction>(V))
409     return !L->contains(I->getParent());
410   return true;
411 }
412
413 const Type *SCEVUnknown::getType() const {
414   return V->getType();
415 }
416
417 void SCEVUnknown::print(std::ostream &OS) const {
418   WriteAsOperand(OS, V, false);
419 }
420
421 //===----------------------------------------------------------------------===//
422 //                               SCEV Utilities
423 //===----------------------------------------------------------------------===//
424
425 namespace {
426   /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
427   /// than the complexity of the RHS.  This comparator is used to canonicalize
428   /// expressions.
429   struct VISIBILITY_HIDDEN SCEVComplexityCompare {
430     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
431       return LHS->getSCEVType() < RHS->getSCEVType();
432     }
433   };
434 }
435
436 /// GroupByComplexity - Given a list of SCEV objects, order them by their
437 /// complexity, and group objects of the same complexity together by value.
438 /// When this routine is finished, we know that any duplicates in the vector are
439 /// consecutive and that complexity is monotonically increasing.
440 ///
441 /// Note that we go take special precautions to ensure that we get determinstic
442 /// results from this routine.  In other words, we don't want the results of
443 /// this to depend on where the addresses of various SCEV objects happened to
444 /// land in memory.
445 ///
446 static void GroupByComplexity(std::vector<SCEVHandle> &Ops) {
447   if (Ops.size() < 2) return;  // Noop
448   if (Ops.size() == 2) {
449     // This is the common case, which also happens to be trivially simple.
450     // Special case it.
451     if (SCEVComplexityCompare()(Ops[1], Ops[0]))
452       std::swap(Ops[0], Ops[1]);
453     return;
454   }
455
456   // Do the rough sort by complexity.
457   std::sort(Ops.begin(), Ops.end(), SCEVComplexityCompare());
458
459   // Now that we are sorted by complexity, group elements of the same
460   // complexity.  Note that this is, at worst, N^2, but the vector is likely to
461   // be extremely short in practice.  Note that we take this approach because we
462   // do not want to depend on the addresses of the objects we are grouping.
463   for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
464     SCEV *S = Ops[i];
465     unsigned Complexity = S->getSCEVType();
466
467     // If there are any objects of the same complexity and same value as this
468     // one, group them.
469     for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
470       if (Ops[j] == S) { // Found a duplicate.
471         // Move it to immediately after i'th element.
472         std::swap(Ops[i+1], Ops[j]);
473         ++i;   // no need to rescan it.
474         if (i == e-2) return;  // Done!
475       }
476     }
477   }
478 }
479
480
481
482 //===----------------------------------------------------------------------===//
483 //                      Simple SCEV method implementations
484 //===----------------------------------------------------------------------===//
485
486 /// getIntegerSCEV - Given an integer or FP type, create a constant for the
487 /// specified signed integer value and return a SCEV for the constant.
488 SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
489   Constant *C;
490   if (Val == 0)
491     C = Constant::getNullValue(Ty);
492   else if (Ty->isFloatingPoint())
493     C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : 
494                                 APFloat::IEEEdouble, Val));
495   else 
496     C = ConstantInt::get(Ty, Val);
497   return getUnknown(C);
498 }
499
500 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
501 ///
502 SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
503   if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
504     return getUnknown(ConstantExpr::getNeg(VC->getValue()));
505
506   return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(V->getType())));
507 }
508
509 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
510 SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
511   if (SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
512     return getUnknown(ConstantExpr::getNot(VC->getValue()));
513
514   SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(V->getType()));
515   return getMinusSCEV(AllOnes, V);
516 }
517
518 /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
519 ///
520 SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
521                                          const SCEVHandle &RHS) {
522   // X - Y --> X + -Y
523   return getAddExpr(LHS, getNegativeSCEV(RHS));
524 }
525
526
527 /// BinomialCoefficient - Compute BC(It, K).  The result has width W.
528 // Assume, K > 0.
529 static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
530                                       ScalarEvolution &SE,
531                                       const IntegerType* ResultTy) {
532   // Handle the simplest case efficiently.
533   if (K == 1)
534     return SE.getTruncateOrZeroExtend(It, ResultTy);
535
536   // We are using the following formula for BC(It, K):
537   //
538   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
539   //
540   // Suppose, W is the bitwidth of the return value.  We must be prepared for
541   // overflow.  Hence, we must assure that the result of our computation is
542   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
543   // safe in modular arithmetic.
544   //
545   // However, this code doesn't use exactly that formula; the formula it uses
546   // is something like the following, where T is the number of factors of 2 in 
547   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
548   // exponentiation:
549   //
550   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
551   //
552   // This formula is trivially equivalent to the previous formula.  However,
553   // this formula can be implemented much more efficiently.  The trick is that
554   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
555   // arithmetic.  To do exact division in modular arithmetic, all we have
556   // to do is multiply by the inverse.  Therefore, this step can be done at
557   // width W.
558   // 
559   // The next issue is how to safely do the division by 2^T.  The way this
560   // is done is by doing the multiplication step at a width of at least W + T
561   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
562   // when we perform the division by 2^T (which is equivalent to a right shift
563   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
564   // truncated out after the division by 2^T.
565   //
566   // In comparison to just directly using the first formula, this technique
567   // is much more efficient; using the first formula requires W * K bits,
568   // but this formula less than W + K bits. Also, the first formula requires
569   // a division step, whereas this formula only requires multiplies and shifts.
570   //
571   // It doesn't matter whether the subtraction step is done in the calculation
572   // width or the input iteration count's width; if the subtraction overflows,
573   // the result must be zero anyway.  We prefer here to do it in the width of
574   // the induction variable because it helps a lot for certain cases; CodeGen
575   // isn't smart enough to ignore the overflow, which leads to much less
576   // efficient code if the width of the subtraction is wider than the native
577   // register width.
578   //
579   // (It's possible to not widen at all by pulling out factors of 2 before
580   // the multiplication; for example, K=2 can be calculated as
581   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
582   // extra arithmetic, so it's not an obvious win, and it gets
583   // much more complicated for K > 3.)
584
585   // Protection from insane SCEVs; this bound is conservative,
586   // but it probably doesn't matter.
587   if (K > 1000)
588     return new SCEVCouldNotCompute();
589
590   unsigned W = ResultTy->getBitWidth();
591
592   // Calculate K! / 2^T and T; we divide out the factors of two before
593   // multiplying for calculating K! / 2^T to avoid overflow.
594   // Other overflow doesn't matter because we only care about the bottom
595   // W bits of the result.
596   APInt OddFactorial(W, 1);
597   unsigned T = 1;
598   for (unsigned i = 3; i <= K; ++i) {
599     APInt Mult(W, i);
600     unsigned TwoFactors = Mult.countTrailingZeros();
601     T += TwoFactors;
602     Mult = Mult.lshr(TwoFactors);
603     OddFactorial *= Mult;
604   }
605
606   // We need at least W + T bits for the multiplication step
607   // FIXME: A temporary hack; we round up the bitwidths
608   // to the nearest power of 2 to be nice to the code generator.
609   unsigned CalculationBits = 1U << Log2_32_Ceil(W + T);
610   // FIXME: Temporary hack to avoid generating integers that are too wide.
611   // Although, it's not completely clear how to determine how much
612   // widening is safe; for example, on X86, we can't really widen
613   // beyond 64 because we need to be able to do multiplication
614   // that's CalculationBits wide, but on X86-64, we can safely widen up to
615   // 128 bits.
616   if (CalculationBits > 64)
617     return new SCEVCouldNotCompute();
618
619   // Calcuate 2^T, at width T+W.
620   APInt DivFactor = APInt(CalculationBits, 1).shl(T);
621
622   // Calculate the multiplicative inverse of K! / 2^T;
623   // this multiplication factor will perform the exact division by
624   // K! / 2^T.
625   APInt Mod = APInt::getSignedMinValue(W+1);
626   APInt MultiplyFactor = OddFactorial.zext(W+1);
627   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
628   MultiplyFactor = MultiplyFactor.trunc(W);
629
630   // Calculate the product, at width T+W
631   const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
632   SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
633   for (unsigned i = 1; i != K; ++i) {
634     SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
635     Dividend = SE.getMulExpr(Dividend,
636                              SE.getTruncateOrZeroExtend(S, CalculationTy));
637   }
638
639   // Divide by 2^T
640   SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
641
642   // Truncate the result, and divide by K! / 2^T.
643
644   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
645                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
646 }
647
648 /// evaluateAtIteration - Return the value of this chain of recurrences at
649 /// the specified iteration number.  We can evaluate this recurrence by
650 /// multiplying each element in the chain by the binomial coefficient
651 /// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
652 ///
653 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
654 ///
655 /// where BC(It, k) stands for binomial coefficient.
656 ///
657 SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
658                                                ScalarEvolution &SE) const {
659   SCEVHandle Result = getStart();
660   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
661     // The computation is correct in the face of overflow provided that the
662     // multiplication is performed _after_ the evaluation of the binomial
663     // coefficient.
664     SCEVHandle Coeff = BinomialCoefficient(It, i, SE,
665                                            cast<IntegerType>(getType()));
666     if (isa<SCEVCouldNotCompute>(Coeff))
667       return Coeff;
668
669     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
670   }
671   return Result;
672 }
673
674 //===----------------------------------------------------------------------===//
675 //                    SCEV Expression folder implementations
676 //===----------------------------------------------------------------------===//
677
678 SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, const Type *Ty) {
679   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
680     return getUnknown(
681         ConstantExpr::getTrunc(SC->getValue(), Ty));
682
683   // If the input value is a chrec scev made out of constants, truncate
684   // all of the constants.
685   if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
686     std::vector<SCEVHandle> Operands;
687     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
688       // FIXME: This should allow truncation of other expression types!
689       if (isa<SCEVConstant>(AddRec->getOperand(i)))
690         Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
691       else
692         break;
693     if (Operands.size() == AddRec->getNumOperands())
694       return getAddRecExpr(Operands, AddRec->getLoop());
695   }
696
697   SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)];
698   if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty);
699   return Result;
700 }
701
702 SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, const Type *Ty) {
703   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
704     return getUnknown(
705         ConstantExpr::getZExt(SC->getValue(), Ty));
706
707   // FIXME: If the input value is a chrec scev, and we can prove that the value
708   // did not overflow the old, smaller, value, we can zero extend all of the
709   // operands (often constants).  This would allow analysis of something like
710   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
711
712   SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)];
713   if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty);
714   return Result;
715 }
716
717 SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, const Type *Ty) {
718   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
719     return getUnknown(
720         ConstantExpr::getSExt(SC->getValue(), Ty));
721
722   // FIXME: If the input value is a chrec scev, and we can prove that the value
723   // did not overflow the old, smaller, value, we can sign extend all of the
724   // operands (often constants).  This would allow analysis of something like
725   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
726
727   SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
728   if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
729   return Result;
730 }
731
732 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion
733 /// of the input value to the specified type.  If the type must be
734 /// extended, it is zero extended.
735 SCEVHandle ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V,
736                                                     const Type *Ty) {
737   const Type *SrcTy = V->getType();
738   assert(SrcTy->isInteger() && Ty->isInteger() &&
739          "Cannot truncate or zero extend with non-integer arguments!");
740   if (SrcTy->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits())
741     return V;  // No conversion
742   if (SrcTy->getPrimitiveSizeInBits() > Ty->getPrimitiveSizeInBits())
743     return getTruncateExpr(V, Ty);
744   return getZeroExtendExpr(V, Ty);
745 }
746
747 // get - Get a canonical add expression, or something simpler if possible.
748 SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
749   assert(!Ops.empty() && "Cannot get empty add!");
750   if (Ops.size() == 1) return Ops[0];
751
752   // Sort by complexity, this groups all similar expression types together.
753   GroupByComplexity(Ops);
754
755   // If there are any constants, fold them together.
756   unsigned Idx = 0;
757   if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
758     ++Idx;
759     assert(Idx < Ops.size());
760     while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
761       // We found two constants, fold them together!
762       ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + 
763                                            RHSC->getValue()->getValue());
764       Ops[0] = getConstant(Fold);
765       Ops.erase(Ops.begin()+1);  // Erase the folded element
766       if (Ops.size() == 1) return Ops[0];
767       LHSC = cast<SCEVConstant>(Ops[0]);
768     }
769
770     // If we are left with a constant zero being added, strip it off.
771     if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
772       Ops.erase(Ops.begin());
773       --Idx;
774     }
775   }
776
777   if (Ops.size() == 1) return Ops[0];
778
779   // Okay, check to see if the same value occurs in the operand list twice.  If
780   // so, merge them together into an multiply expression.  Since we sorted the
781   // list, these values are required to be adjacent.
782   const Type *Ty = Ops[0]->getType();
783   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
784     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
785       // Found a match, merge the two values into a multiply, and add any
786       // remaining values to the result.
787       SCEVHandle Two = getIntegerSCEV(2, Ty);
788       SCEVHandle Mul = getMulExpr(Ops[i], Two);
789       if (Ops.size() == 2)
790         return Mul;
791       Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
792       Ops.push_back(Mul);
793       return getAddExpr(Ops);
794     }
795
796   // Now we know the first non-constant operand.  Skip past any cast SCEVs.
797   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
798     ++Idx;
799
800   // If there are add operands they would be next.
801   if (Idx < Ops.size()) {
802     bool DeletedAdd = false;
803     while (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
804       // If we have an add, expand the add operands onto the end of the operands
805       // list.
806       Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
807       Ops.erase(Ops.begin()+Idx);
808       DeletedAdd = true;
809     }
810
811     // If we deleted at least one add, we added operands to the end of the list,
812     // and they are not necessarily sorted.  Recurse to resort and resimplify
813     // any operands we just aquired.
814     if (DeletedAdd)
815       return getAddExpr(Ops);
816   }
817
818   // Skip over the add expression until we get to a multiply.
819   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
820     ++Idx;
821
822   // If we are adding something to a multiply expression, make sure the
823   // something is not already an operand of the multiply.  If so, merge it into
824   // the multiply.
825   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
826     SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
827     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
828       SCEV *MulOpSCEV = Mul->getOperand(MulOp);
829       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
830         if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) {
831           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
832           SCEVHandle InnerMul = Mul->getOperand(MulOp == 0);
833           if (Mul->getNumOperands() != 2) {
834             // If the multiply has more than two operands, we must get the
835             // Y*Z term.
836             std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
837             MulOps.erase(MulOps.begin()+MulOp);
838             InnerMul = getMulExpr(MulOps);
839           }
840           SCEVHandle One = getIntegerSCEV(1, Ty);
841           SCEVHandle AddOne = getAddExpr(InnerMul, One);
842           SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]);
843           if (Ops.size() == 2) return OuterMul;
844           if (AddOp < Idx) {
845             Ops.erase(Ops.begin()+AddOp);
846             Ops.erase(Ops.begin()+Idx-1);
847           } else {
848             Ops.erase(Ops.begin()+Idx);
849             Ops.erase(Ops.begin()+AddOp-1);
850           }
851           Ops.push_back(OuterMul);
852           return getAddExpr(Ops);
853         }
854
855       // Check this multiply against other multiplies being added together.
856       for (unsigned OtherMulIdx = Idx+1;
857            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
858            ++OtherMulIdx) {
859         SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
860         // If MulOp occurs in OtherMul, we can fold the two multiplies
861         // together.
862         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
863              OMulOp != e; ++OMulOp)
864           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
865             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
866             SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0);
867             if (Mul->getNumOperands() != 2) {
868               std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
869               MulOps.erase(MulOps.begin()+MulOp);
870               InnerMul1 = getMulExpr(MulOps);
871             }
872             SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0);
873             if (OtherMul->getNumOperands() != 2) {
874               std::vector<SCEVHandle> MulOps(OtherMul->op_begin(),
875                                              OtherMul->op_end());
876               MulOps.erase(MulOps.begin()+OMulOp);
877               InnerMul2 = getMulExpr(MulOps);
878             }
879             SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
880             SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
881             if (Ops.size() == 2) return OuterMul;
882             Ops.erase(Ops.begin()+Idx);
883             Ops.erase(Ops.begin()+OtherMulIdx-1);
884             Ops.push_back(OuterMul);
885             return getAddExpr(Ops);
886           }
887       }
888     }
889   }
890
891   // If there are any add recurrences in the operands list, see if any other
892   // added values are loop invariant.  If so, we can fold them into the
893   // recurrence.
894   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
895     ++Idx;
896
897   // Scan over all recurrences, trying to fold loop invariants into them.
898   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
899     // Scan all of the other operands to this add and add them to the vector if
900     // they are loop invariant w.r.t. the recurrence.
901     std::vector<SCEVHandle> LIOps;
902     SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
903     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
904       if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
905         LIOps.push_back(Ops[i]);
906         Ops.erase(Ops.begin()+i);
907         --i; --e;
908       }
909
910     // If we found some loop invariants, fold them into the recurrence.
911     if (!LIOps.empty()) {
912       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
913       LIOps.push_back(AddRec->getStart());
914
915       std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end());
916       AddRecOps[0] = getAddExpr(LIOps);
917
918       SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
919       // If all of the other operands were loop invariant, we are done.
920       if (Ops.size() == 1) return NewRec;
921
922       // Otherwise, add the folded AddRec by the non-liv parts.
923       for (unsigned i = 0;; ++i)
924         if (Ops[i] == AddRec) {
925           Ops[i] = NewRec;
926           break;
927         }
928       return getAddExpr(Ops);
929     }
930
931     // Okay, if there weren't any loop invariants to be folded, check to see if
932     // there are multiple AddRec's with the same loop induction variable being
933     // added together.  If so, we can fold them.
934     for (unsigned OtherIdx = Idx+1;
935          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
936       if (OtherIdx != Idx) {
937         SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
938         if (AddRec->getLoop() == OtherAddRec->getLoop()) {
939           // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
940           std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end());
941           for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
942             if (i >= NewOps.size()) {
943               NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
944                             OtherAddRec->op_end());
945               break;
946             }
947             NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
948           }
949           SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
950
951           if (Ops.size() == 2) return NewAddRec;
952
953           Ops.erase(Ops.begin()+Idx);
954           Ops.erase(Ops.begin()+OtherIdx-1);
955           Ops.push_back(NewAddRec);
956           return getAddExpr(Ops);
957         }
958       }
959
960     // Otherwise couldn't fold anything into this recurrence.  Move onto the
961     // next one.
962   }
963
964   // Okay, it looks like we really DO need an add expr.  Check to see if we
965   // already have one, otherwise create a new one.
966   std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
967   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr,
968                                                                  SCEVOps)];
969   if (Result == 0) Result = new SCEVAddExpr(Ops);
970   return Result;
971 }
972
973
974 SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
975   assert(!Ops.empty() && "Cannot get empty mul!");
976
977   // Sort by complexity, this groups all similar expression types together.
978   GroupByComplexity(Ops);
979
980   // If there are any constants, fold them together.
981   unsigned Idx = 0;
982   if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
983
984     // C1*(C2+V) -> C1*C2 + C1*V
985     if (Ops.size() == 2)
986       if (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
987         if (Add->getNumOperands() == 2 &&
988             isa<SCEVConstant>(Add->getOperand(0)))
989           return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
990                             getMulExpr(LHSC, Add->getOperand(1)));
991
992
993     ++Idx;
994     while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
995       // We found two constants, fold them together!
996       ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * 
997                                            RHSC->getValue()->getValue());
998       Ops[0] = getConstant(Fold);
999       Ops.erase(Ops.begin()+1);  // Erase the folded element
1000       if (Ops.size() == 1) return Ops[0];
1001       LHSC = cast<SCEVConstant>(Ops[0]);
1002     }
1003
1004     // If we are left with a constant one being multiplied, strip it off.
1005     if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1006       Ops.erase(Ops.begin());
1007       --Idx;
1008     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1009       // If we have a multiply of zero, it will always be zero.
1010       return Ops[0];
1011     }
1012   }
1013
1014   // Skip over the add expression until we get to a multiply.
1015   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1016     ++Idx;
1017
1018   if (Ops.size() == 1)
1019     return Ops[0];
1020
1021   // If there are mul operands inline them all into this expression.
1022   if (Idx < Ops.size()) {
1023     bool DeletedMul = false;
1024     while (SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1025       // If we have an mul, expand the mul operands onto the end of the operands
1026       // list.
1027       Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1028       Ops.erase(Ops.begin()+Idx);
1029       DeletedMul = true;
1030     }
1031
1032     // If we deleted at least one mul, we added operands to the end of the list,
1033     // and they are not necessarily sorted.  Recurse to resort and resimplify
1034     // any operands we just aquired.
1035     if (DeletedMul)
1036       return getMulExpr(Ops);
1037   }
1038
1039   // If there are any add recurrences in the operands list, see if any other
1040   // added values are loop invariant.  If so, we can fold them into the
1041   // recurrence.
1042   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1043     ++Idx;
1044
1045   // Scan over all recurrences, trying to fold loop invariants into them.
1046   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1047     // Scan all of the other operands to this mul and add them to the vector if
1048     // they are loop invariant w.r.t. the recurrence.
1049     std::vector<SCEVHandle> LIOps;
1050     SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1051     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1052       if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1053         LIOps.push_back(Ops[i]);
1054         Ops.erase(Ops.begin()+i);
1055         --i; --e;
1056       }
1057
1058     // If we found some loop invariants, fold them into the recurrence.
1059     if (!LIOps.empty()) {
1060       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1061       std::vector<SCEVHandle> NewOps;
1062       NewOps.reserve(AddRec->getNumOperands());
1063       if (LIOps.size() == 1) {
1064         SCEV *Scale = LIOps[0];
1065         for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1066           NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1067       } else {
1068         for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1069           std::vector<SCEVHandle> MulOps(LIOps);
1070           MulOps.push_back(AddRec->getOperand(i));
1071           NewOps.push_back(getMulExpr(MulOps));
1072         }
1073       }
1074
1075       SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1076
1077       // If all of the other operands were loop invariant, we are done.
1078       if (Ops.size() == 1) return NewRec;
1079
1080       // Otherwise, multiply the folded AddRec by the non-liv parts.
1081       for (unsigned i = 0;; ++i)
1082         if (Ops[i] == AddRec) {
1083           Ops[i] = NewRec;
1084           break;
1085         }
1086       return getMulExpr(Ops);
1087     }
1088
1089     // Okay, if there weren't any loop invariants to be folded, check to see if
1090     // there are multiple AddRec's with the same loop induction variable being
1091     // multiplied together.  If so, we can fold them.
1092     for (unsigned OtherIdx = Idx+1;
1093          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1094       if (OtherIdx != Idx) {
1095         SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1096         if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1097           // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
1098           SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1099           SCEVHandle NewStart = getMulExpr(F->getStart(),
1100                                                  G->getStart());
1101           SCEVHandle B = F->getStepRecurrence(*this);
1102           SCEVHandle D = G->getStepRecurrence(*this);
1103           SCEVHandle NewStep = getAddExpr(getMulExpr(F, D),
1104                                           getMulExpr(G, B),
1105                                           getMulExpr(B, D));
1106           SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep,
1107                                                F->getLoop());
1108           if (Ops.size() == 2) return NewAddRec;
1109
1110           Ops.erase(Ops.begin()+Idx);
1111           Ops.erase(Ops.begin()+OtherIdx-1);
1112           Ops.push_back(NewAddRec);
1113           return getMulExpr(Ops);
1114         }
1115       }
1116
1117     // Otherwise couldn't fold anything into this recurrence.  Move onto the
1118     // next one.
1119   }
1120
1121   // Okay, it looks like we really DO need an mul expr.  Check to see if we
1122   // already have one, otherwise create a new one.
1123   std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1124   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr,
1125                                                                  SCEVOps)];
1126   if (Result == 0)
1127     Result = new SCEVMulExpr(Ops);
1128   return Result;
1129 }
1130
1131 SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) {
1132   if (LHS == RHS)
1133     return getIntegerSCEV(1, LHS->getType());  // X udiv X --> 1
1134
1135   if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1136     if (RHSC->getValue()->equalsInt(1))
1137       return LHS;                              // X udiv 1 --> X
1138
1139     if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1140       Constant *LHSCV = LHSC->getValue();
1141       Constant *RHSCV = RHSC->getValue();
1142       return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV));
1143     }
1144   }
1145
1146   SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)];
1147   if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS);
1148   return Result;
1149 }
1150
1151 SCEVHandle ScalarEvolution::getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) {
1152   if (LHS == RHS)                            
1153     return getIntegerSCEV(1, LHS->getType());  // X sdiv X --> 1
1154
1155   if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1156     if (RHSC->getValue()->equalsInt(1))
1157       return LHS;                              // X sdiv 1 --> X
1158
1159     if (RHSC->getValue()->isAllOnesValue())
1160       return getNegativeSCEV(LHS);             // X sdiv -1 --> -X
1161
1162     if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1163       Constant *LHSCV = LHSC->getValue();
1164       Constant *RHSCV = RHSC->getValue();
1165       return getUnknown(ConstantExpr::getSDiv(LHSCV, RHSCV));
1166     }
1167   }
1168
1169   SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)];
1170   if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS);
1171   return Result;
1172 }
1173
1174
1175 /// SCEVAddRecExpr::get - Get a add recurrence expression for the
1176 /// specified loop.  Simplify the expression as much as possible.
1177 SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start,
1178                                const SCEVHandle &Step, const Loop *L) {
1179   std::vector<SCEVHandle> Operands;
1180   Operands.push_back(Start);
1181   if (SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1182     if (StepChrec->getLoop() == L) {
1183       Operands.insert(Operands.end(), StepChrec->op_begin(),
1184                       StepChrec->op_end());
1185       return getAddRecExpr(Operands, L);
1186     }
1187
1188   Operands.push_back(Step);
1189   return getAddRecExpr(Operands, L);
1190 }
1191
1192 /// SCEVAddRecExpr::get - Get a add recurrence expression for the
1193 /// specified loop.  Simplify the expression as much as possible.
1194 SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
1195                                const Loop *L) {
1196   if (Operands.size() == 1) return Operands[0];
1197
1198   if (Operands.back()->isZero()) {
1199     Operands.pop_back();
1200     return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
1201   }
1202
1203   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
1204   if (SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1205     const Loop* NestedLoop = NestedAR->getLoop();
1206     if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
1207       std::vector<SCEVHandle> NestedOperands(NestedAR->op_begin(),
1208                                              NestedAR->op_end());
1209       SCEVHandle NestedARHandle(NestedAR);
1210       Operands[0] = NestedAR->getStart();
1211       NestedOperands[0] = getAddRecExpr(Operands, L);
1212       return getAddRecExpr(NestedOperands, NestedLoop);
1213     }
1214   }
1215
1216   SCEVAddRecExpr *&Result =
1217     (*SCEVAddRecExprs)[std::make_pair(L, std::vector<SCEV*>(Operands.begin(),
1218                                                             Operands.end()))];
1219   if (Result == 0) Result = new SCEVAddRecExpr(Operands, L);
1220   return Result;
1221 }
1222
1223 SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
1224                                         const SCEVHandle &RHS) {
1225   std::vector<SCEVHandle> Ops;
1226   Ops.push_back(LHS);
1227   Ops.push_back(RHS);
1228   return getSMaxExpr(Ops);
1229 }
1230
1231 SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
1232   assert(!Ops.empty() && "Cannot get empty smax!");
1233   if (Ops.size() == 1) return Ops[0];
1234
1235   // Sort by complexity, this groups all similar expression types together.
1236   GroupByComplexity(Ops);
1237
1238   // If there are any constants, fold them together.
1239   unsigned Idx = 0;
1240   if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1241     ++Idx;
1242     assert(Idx < Ops.size());
1243     while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1244       // We found two constants, fold them together!
1245       ConstantInt *Fold = ConstantInt::get(
1246                               APIntOps::smax(LHSC->getValue()->getValue(),
1247                                              RHSC->getValue()->getValue()));
1248       Ops[0] = getConstant(Fold);
1249       Ops.erase(Ops.begin()+1);  // Erase the folded element
1250       if (Ops.size() == 1) return Ops[0];
1251       LHSC = cast<SCEVConstant>(Ops[0]);
1252     }
1253
1254     // If we are left with a constant -inf, strip it off.
1255     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
1256       Ops.erase(Ops.begin());
1257       --Idx;
1258     }
1259   }
1260
1261   if (Ops.size() == 1) return Ops[0];
1262
1263   // Find the first SMax
1264   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
1265     ++Idx;
1266
1267   // Check to see if one of the operands is an SMax. If so, expand its operands
1268   // onto our operand list, and recurse to simplify.
1269   if (Idx < Ops.size()) {
1270     bool DeletedSMax = false;
1271     while (SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
1272       Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
1273       Ops.erase(Ops.begin()+Idx);
1274       DeletedSMax = true;
1275     }
1276
1277     if (DeletedSMax)
1278       return getSMaxExpr(Ops);
1279   }
1280
1281   // Okay, check to see if the same value occurs in the operand list twice.  If
1282   // so, delete one.  Since we sorted the list, these values are required to
1283   // be adjacent.
1284   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1285     if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
1286       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1287       --i; --e;
1288     }
1289
1290   if (Ops.size() == 1) return Ops[0];
1291
1292   assert(!Ops.empty() && "Reduced smax down to nothing!");
1293
1294   // Okay, it looks like we really DO need an smax expr.  Check to see if we
1295   // already have one, otherwise create a new one.
1296   std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1297   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
1298                                                                  SCEVOps)];
1299   if (Result == 0) Result = new SCEVSMaxExpr(Ops);
1300   return Result;
1301 }
1302
1303 SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS,
1304                                         const SCEVHandle &RHS) {
1305   std::vector<SCEVHandle> Ops;
1306   Ops.push_back(LHS);
1307   Ops.push_back(RHS);
1308   return getUMaxExpr(Ops);
1309 }
1310
1311 SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) {
1312   assert(!Ops.empty() && "Cannot get empty umax!");
1313   if (Ops.size() == 1) return Ops[0];
1314
1315   // Sort by complexity, this groups all similar expression types together.
1316   GroupByComplexity(Ops);
1317
1318   // If there are any constants, fold them together.
1319   unsigned Idx = 0;
1320   if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1321     ++Idx;
1322     assert(Idx < Ops.size());
1323     while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1324       // We found two constants, fold them together!
1325       ConstantInt *Fold = ConstantInt::get(
1326                               APIntOps::umax(LHSC->getValue()->getValue(),
1327                                              RHSC->getValue()->getValue()));
1328       Ops[0] = getConstant(Fold);
1329       Ops.erase(Ops.begin()+1);  // Erase the folded element
1330       if (Ops.size() == 1) return Ops[0];
1331       LHSC = cast<SCEVConstant>(Ops[0]);
1332     }
1333
1334     // If we are left with a constant zero, strip it off.
1335     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
1336       Ops.erase(Ops.begin());
1337       --Idx;
1338     }
1339   }
1340
1341   if (Ops.size() == 1) return Ops[0];
1342
1343   // Find the first UMax
1344   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
1345     ++Idx;
1346
1347   // Check to see if one of the operands is a UMax. If so, expand its operands
1348   // onto our operand list, and recurse to simplify.
1349   if (Idx < Ops.size()) {
1350     bool DeletedUMax = false;
1351     while (SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
1352       Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
1353       Ops.erase(Ops.begin()+Idx);
1354       DeletedUMax = true;
1355     }
1356
1357     if (DeletedUMax)
1358       return getUMaxExpr(Ops);
1359   }
1360
1361   // Okay, check to see if the same value occurs in the operand list twice.  If
1362   // so, delete one.  Since we sorted the list, these values are required to
1363   // be adjacent.
1364   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1365     if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
1366       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1367       --i; --e;
1368     }
1369
1370   if (Ops.size() == 1) return Ops[0];
1371
1372   assert(!Ops.empty() && "Reduced umax down to nothing!");
1373
1374   // Okay, it looks like we really DO need a umax expr.  Check to see if we
1375   // already have one, otherwise create a new one.
1376   std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
1377   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr,
1378                                                                  SCEVOps)];
1379   if (Result == 0) Result = new SCEVUMaxExpr(Ops);
1380   return Result;
1381 }
1382
1383 SCEVHandle ScalarEvolution::getUnknown(Value *V) {
1384   if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
1385     return getConstant(CI);
1386   SCEVUnknown *&Result = (*SCEVUnknowns)[V];
1387   if (Result == 0) Result = new SCEVUnknown(V);
1388   return Result;
1389 }
1390
1391
1392 //===----------------------------------------------------------------------===//
1393 //             ScalarEvolutionsImpl Definition and Implementation
1394 //===----------------------------------------------------------------------===//
1395 //
1396 /// ScalarEvolutionsImpl - This class implements the main driver for the scalar
1397 /// evolution code.
1398 ///
1399 namespace {
1400   struct VISIBILITY_HIDDEN ScalarEvolutionsImpl {
1401     /// SE - A reference to the public ScalarEvolution object.
1402     ScalarEvolution &SE;
1403
1404     /// F - The function we are analyzing.
1405     ///
1406     Function &F;
1407
1408     /// LI - The loop information for the function we are currently analyzing.
1409     ///
1410     LoopInfo &LI;
1411
1412     /// UnknownValue - This SCEV is used to represent unknown trip counts and
1413     /// things.
1414     SCEVHandle UnknownValue;
1415
1416     /// Scalars - This is a cache of the scalars we have analyzed so far.
1417     ///
1418     std::map<Value*, SCEVHandle> Scalars;
1419
1420     /// IterationCounts - Cache the iteration count of the loops for this
1421     /// function as they are computed.
1422     std::map<const Loop*, SCEVHandle> IterationCounts;
1423
1424     /// ConstantEvolutionLoopExitValue - This map contains entries for all of
1425     /// the PHI instructions that we attempt to compute constant evolutions for.
1426     /// This allows us to avoid potentially expensive recomputation of these
1427     /// properties.  An instruction maps to null if we are unable to compute its
1428     /// exit value.
1429     std::map<PHINode*, Constant*> ConstantEvolutionLoopExitValue;
1430
1431   public:
1432     ScalarEvolutionsImpl(ScalarEvolution &se, Function &f, LoopInfo &li)
1433       : SE(se), F(f), LI(li), UnknownValue(new SCEVCouldNotCompute()) {}
1434
1435     /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1436     /// expression and create a new one.
1437     SCEVHandle getSCEV(Value *V);
1438
1439     /// hasSCEV - Return true if the SCEV for this value has already been
1440     /// computed.
1441     bool hasSCEV(Value *V) const {
1442       return Scalars.count(V);
1443     }
1444
1445     /// setSCEV - Insert the specified SCEV into the map of current SCEVs for
1446     /// the specified value.
1447     void setSCEV(Value *V, const SCEVHandle &H) {
1448       bool isNew = Scalars.insert(std::make_pair(V, H)).second;
1449       assert(isNew && "This entry already existed!");
1450       isNew = false;
1451     }
1452
1453
1454     /// getSCEVAtScope - Compute the value of the specified expression within
1455     /// the indicated loop (which may be null to indicate in no loop).  If the
1456     /// expression cannot be evaluated, return UnknownValue itself.
1457     SCEVHandle getSCEVAtScope(SCEV *V, const Loop *L);
1458
1459
1460     /// hasLoopInvariantIterationCount - Return true if the specified loop has
1461     /// an analyzable loop-invariant iteration count.
1462     bool hasLoopInvariantIterationCount(const Loop *L);
1463
1464     /// getIterationCount - If the specified loop has a predictable iteration
1465     /// count, return it.  Note that it is not valid to call this method on a
1466     /// loop without a loop-invariant iteration count.
1467     SCEVHandle getIterationCount(const Loop *L);
1468
1469     /// deleteValueFromRecords - This method should be called by the
1470     /// client before it removes a value from the program, to make sure
1471     /// that no dangling references are left around.
1472     void deleteValueFromRecords(Value *V);
1473
1474   private:
1475     /// createSCEV - We know that there is no SCEV for the specified value.
1476     /// Analyze the expression.
1477     SCEVHandle createSCEV(Value *V);
1478
1479     /// createNodeForPHI - Provide the special handling we need to analyze PHI
1480     /// SCEVs.
1481     SCEVHandle createNodeForPHI(PHINode *PN);
1482
1483     /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value
1484     /// for the specified instruction and replaces any references to the
1485     /// symbolic value SymName with the specified value.  This is used during
1486     /// PHI resolution.
1487     void ReplaceSymbolicValueWithConcrete(Instruction *I,
1488                                           const SCEVHandle &SymName,
1489                                           const SCEVHandle &NewVal);
1490
1491     /// ComputeIterationCount - Compute the number of times the specified loop
1492     /// will iterate.
1493     SCEVHandle ComputeIterationCount(const Loop *L);
1494
1495     /// ComputeLoadConstantCompareIterationCount - Given an exit condition of
1496     /// 'icmp op load X, cst', try to see if we can compute the trip count.
1497     SCEVHandle ComputeLoadConstantCompareIterationCount(LoadInst *LI,
1498                                                         Constant *RHS,
1499                                                         const Loop *L,
1500                                                         ICmpInst::Predicate p);
1501
1502     /// ComputeIterationCountExhaustively - If the trip is known to execute a
1503     /// constant number of times (the condition evolves only from constants),
1504     /// try to evaluate a few iterations of the loop until we get the exit
1505     /// condition gets a value of ExitWhen (true or false).  If we cannot
1506     /// evaluate the trip count of the loop, return UnknownValue.
1507     SCEVHandle ComputeIterationCountExhaustively(const Loop *L, Value *Cond,
1508                                                  bool ExitWhen);
1509
1510     /// HowFarToZero - Return the number of times a backedge comparing the
1511     /// specified value to zero will execute.  If not computable, return
1512     /// UnknownValue.
1513     SCEVHandle HowFarToZero(SCEV *V, const Loop *L);
1514
1515     /// HowFarToNonZero - Return the number of times a backedge checking the
1516     /// specified value for nonzero will execute.  If not computable, return
1517     /// UnknownValue.
1518     SCEVHandle HowFarToNonZero(SCEV *V, const Loop *L);
1519
1520     /// HowManyLessThans - Return the number of times a backedge containing the
1521     /// specified less-than comparison will execute.  If not computable, return
1522     /// UnknownValue. isSigned specifies whether the less-than is signed.
1523     SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
1524                                 bool isSigned, bool trueWhenEqual);
1525
1526     /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
1527     /// (which may not be an immediate predecessor) which has exactly one
1528     /// successor from which BB is reachable, or null if no such block is
1529     /// found.
1530     BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB);
1531
1532     /// executesAtLeastOnce - Test whether entry to the loop is protected by
1533     /// a conditional between LHS and RHS.
1534     bool executesAtLeastOnce(const Loop *L, bool isSigned, bool trueWhenEqual,
1535                              SCEV *LHS, SCEV *RHS);
1536
1537     /// potentialInfiniteLoop - Test whether the loop might jump over the exit value
1538     /// due to wrapping.
1539     bool potentialInfiniteLoop(SCEV *Stride, SCEV *RHS, bool isSigned,
1540                                bool trueWhenEqual);
1541
1542     /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
1543     /// in the header of its containing loop, we know the loop executes a
1544     /// constant number of times, and the PHI node is just a recurrence
1545     /// involving constants, fold it.
1546     Constant *getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its,
1547                                                 const Loop *L);
1548   };
1549 }
1550
1551 //===----------------------------------------------------------------------===//
1552 //            Basic SCEV Analysis and PHI Idiom Recognition Code
1553 //
1554
1555 /// deleteValueFromRecords - This method should be called by the
1556 /// client before it removes an instruction from the program, to make sure
1557 /// that no dangling references are left around.
1558 void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) {
1559   SmallVector<Value *, 16> Worklist;
1560
1561   if (Scalars.erase(V)) {
1562     if (PHINode *PN = dyn_cast<PHINode>(V))
1563       ConstantEvolutionLoopExitValue.erase(PN);
1564     Worklist.push_back(V);
1565   }
1566
1567   while (!Worklist.empty()) {
1568     Value *VV = Worklist.back();
1569     Worklist.pop_back();
1570
1571     for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end();
1572          UI != UE; ++UI) {
1573       Instruction *Inst = cast<Instruction>(*UI);
1574       if (Scalars.erase(Inst)) {
1575         if (PHINode *PN = dyn_cast<PHINode>(VV))
1576           ConstantEvolutionLoopExitValue.erase(PN);
1577         Worklist.push_back(Inst);
1578       }
1579     }
1580   }
1581 }
1582
1583
1584 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1585 /// expression and create a new one.
1586 SCEVHandle ScalarEvolutionsImpl::getSCEV(Value *V) {
1587   assert(V->getType() != Type::VoidTy && "Can't analyze void expressions!");
1588
1589   std::map<Value*, SCEVHandle>::iterator I = Scalars.find(V);
1590   if (I != Scalars.end()) return I->second;
1591   SCEVHandle S = createSCEV(V);
1592   Scalars.insert(std::make_pair(V, S));
1593   return S;
1594 }
1595
1596 /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
1597 /// the specified instruction and replaces any references to the symbolic value
1598 /// SymName with the specified value.  This is used during PHI resolution.
1599 void ScalarEvolutionsImpl::
1600 ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName,
1601                                  const SCEVHandle &NewVal) {
1602   std::map<Value*, SCEVHandle>::iterator SI = Scalars.find(I);
1603   if (SI == Scalars.end()) return;
1604
1605   SCEVHandle NV =
1606     SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, SE);
1607   if (NV == SI->second) return;  // No change.
1608
1609   SI->second = NV;       // Update the scalars map!
1610
1611   // Any instruction values that use this instruction might also need to be
1612   // updated!
1613   for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
1614        UI != E; ++UI)
1615     ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
1616 }
1617
1618 /// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
1619 /// a loop header, making it a potential recurrence, or it doesn't.
1620 ///
1621 SCEVHandle ScalarEvolutionsImpl::createNodeForPHI(PHINode *PN) {
1622   if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
1623     if (const Loop *L = LI.getLoopFor(PN->getParent()))
1624       if (L->getHeader() == PN->getParent()) {
1625         // If it lives in the loop header, it has two incoming values, one
1626         // from outside the loop, and one from inside.
1627         unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
1628         unsigned BackEdge     = IncomingEdge^1;
1629
1630         // While we are analyzing this PHI node, handle its value symbolically.
1631         SCEVHandle SymbolicName = SE.getUnknown(PN);
1632         assert(Scalars.find(PN) == Scalars.end() &&
1633                "PHI node already processed?");
1634         Scalars.insert(std::make_pair(PN, SymbolicName));
1635
1636         // Using this symbolic name for the PHI, analyze the value coming around
1637         // the back-edge.
1638         SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge));
1639
1640         // NOTE: If BEValue is loop invariant, we know that the PHI node just
1641         // has a special value for the first iteration of the loop.
1642
1643         // If the value coming around the backedge is an add with the symbolic
1644         // value we just inserted, then we found a simple induction variable!
1645         if (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
1646           // If there is a single occurrence of the symbolic value, replace it
1647           // with a recurrence.
1648           unsigned FoundIndex = Add->getNumOperands();
1649           for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
1650             if (Add->getOperand(i) == SymbolicName)
1651               if (FoundIndex == e) {
1652                 FoundIndex = i;
1653                 break;
1654               }
1655
1656           if (FoundIndex != Add->getNumOperands()) {
1657             // Create an add with everything but the specified operand.
1658             std::vector<SCEVHandle> Ops;
1659             for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
1660               if (i != FoundIndex)
1661                 Ops.push_back(Add->getOperand(i));
1662             SCEVHandle Accum = SE.getAddExpr(Ops);
1663
1664             // This is not a valid addrec if the step amount is varying each
1665             // loop iteration, but is not itself an addrec in this loop.
1666             if (Accum->isLoopInvariant(L) ||
1667                 (isa<SCEVAddRecExpr>(Accum) &&
1668                  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
1669               SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
1670               SCEVHandle PHISCEV  = SE.getAddRecExpr(StartVal, Accum, L);
1671
1672               // Okay, for the entire analysis of this edge we assumed the PHI
1673               // to be symbolic.  We now need to go back and update all of the
1674               // entries for the scalars that use the PHI (except for the PHI
1675               // itself) to use the new analyzed value instead of the "symbolic"
1676               // value.
1677               ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
1678               return PHISCEV;
1679             }
1680           }
1681         } else if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
1682           // Otherwise, this could be a loop like this:
1683           //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
1684           // In this case, j = {1,+,1}  and BEValue is j.
1685           // Because the other in-value of i (0) fits the evolution of BEValue
1686           // i really is an addrec evolution.
1687           if (AddRec->getLoop() == L && AddRec->isAffine()) {
1688             SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
1689
1690             // If StartVal = j.start - j.stride, we can use StartVal as the
1691             // initial step of the addrec evolution.
1692             if (StartVal == SE.getMinusSCEV(AddRec->getOperand(0),
1693                                             AddRec->getOperand(1))) {
1694               SCEVHandle PHISCEV = 
1695                  SE.getAddRecExpr(StartVal, AddRec->getOperand(1), L);
1696
1697               // Okay, for the entire analysis of this edge we assumed the PHI
1698               // to be symbolic.  We now need to go back and update all of the
1699               // entries for the scalars that use the PHI (except for the PHI
1700               // itself) to use the new analyzed value instead of the "symbolic"
1701               // value.
1702               ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
1703               return PHISCEV;
1704             }
1705           }
1706         }
1707
1708         return SymbolicName;
1709       }
1710
1711   // If it's not a loop phi, we can't handle it yet.
1712   return SE.getUnknown(PN);
1713 }
1714
1715 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
1716 /// guaranteed to end in (at every loop iteration).  It is, at the same time,
1717 /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
1718 /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
1719 static uint32_t GetMinTrailingZeros(SCEVHandle S) {
1720   if (SCEVConstant *C = dyn_cast<SCEVConstant>(S))
1721     return C->getValue()->getValue().countTrailingZeros();
1722
1723   if (SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
1724     return std::min(GetMinTrailingZeros(T->getOperand()), T->getBitWidth());
1725
1726   if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
1727     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
1728     return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
1729   }
1730
1731   if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
1732     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
1733     return OpRes == E->getOperand()->getBitWidth() ? E->getBitWidth() : OpRes;
1734   }
1735
1736   if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
1737     // The result is the min of all operands results.
1738     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
1739     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
1740       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
1741     return MinOpRes;
1742   }
1743
1744   if (SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
1745     // The result is the sum of all operands results.
1746     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
1747     uint32_t BitWidth = M->getBitWidth();
1748     for (unsigned i = 1, e = M->getNumOperands();
1749          SumOpRes != BitWidth && i != e; ++i)
1750       SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
1751                           BitWidth);
1752     return SumOpRes;
1753   }
1754
1755   if (SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
1756     // The result is the min of all operands results.
1757     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
1758     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
1759       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
1760     return MinOpRes;
1761   }
1762
1763   if (SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
1764     // The result is the min of all operands results.
1765     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
1766     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
1767       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
1768     return MinOpRes;
1769   }
1770
1771   if (SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
1772     // The result is the min of all operands results.
1773     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
1774     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
1775       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
1776     return MinOpRes;
1777   }
1778
1779   // SCEVUDivExpr, SCEVSDivExpr, SCEVUnknown
1780   return 0;
1781 }
1782
1783 /// createSCEV - We know that there is no SCEV for the specified value.
1784 /// Analyze the expression.
1785 ///
1786 SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
1787   if (!isa<IntegerType>(V->getType()))
1788     return SE.getUnknown(V);
1789     
1790   unsigned Opcode = Instruction::UserOp1;
1791   if (Instruction *I = dyn_cast<Instruction>(V))
1792     Opcode = I->getOpcode();
1793   else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
1794     Opcode = CE->getOpcode();
1795   else
1796     return SE.getUnknown(V);
1797
1798   User *U = cast<User>(V);
1799   switch (Opcode) {
1800   case Instruction::Add:
1801     return SE.getAddExpr(getSCEV(U->getOperand(0)),
1802                          getSCEV(U->getOperand(1)));
1803   case Instruction::Mul:
1804     return SE.getMulExpr(getSCEV(U->getOperand(0)),
1805                          getSCEV(U->getOperand(1)));
1806   case Instruction::UDiv:
1807     return SE.getUDivExpr(getSCEV(U->getOperand(0)),
1808                           getSCEV(U->getOperand(1)));
1809   case Instruction::SDiv:
1810     return SE.getSDivExpr(getSCEV(U->getOperand(0)),
1811                           getSCEV(U->getOperand(1)));
1812   case Instruction::Sub:
1813     return SE.getMinusSCEV(getSCEV(U->getOperand(0)),
1814                            getSCEV(U->getOperand(1)));
1815   case Instruction::Or:
1816     // If the RHS of the Or is a constant, we may have something like:
1817     // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
1818     // optimizations will transparently handle this case.
1819     //
1820     // In order for this transformation to be safe, the LHS must be of the
1821     // form X*(2^n) and the Or constant must be less than 2^n.
1822     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
1823       SCEVHandle LHS = getSCEV(U->getOperand(0));
1824       const APInt &CIVal = CI->getValue();
1825       if (GetMinTrailingZeros(LHS) >=
1826           (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
1827         return SE.getAddExpr(LHS, getSCEV(U->getOperand(1)));
1828     }
1829     break;
1830   case Instruction::Xor:
1831     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
1832       // If the RHS of the xor is a signbit, then this is just an add.
1833       // Instcombine turns add of signbit into xor as a strength reduction step.
1834       if (CI->getValue().isSignBit())
1835         return SE.getAddExpr(getSCEV(U->getOperand(0)),
1836                              getSCEV(U->getOperand(1)));
1837
1838       // If the RHS of xor is -1, then this is a not operation.
1839       else if (CI->isAllOnesValue())
1840         return SE.getNotSCEV(getSCEV(U->getOperand(0)));
1841     }
1842     break;
1843
1844   case Instruction::Shl:
1845     // Turn shift left of a constant amount into a multiply.
1846     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
1847       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
1848       Constant *X = ConstantInt::get(
1849         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
1850       return SE.getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
1851     }
1852     break;
1853
1854   case Instruction::LShr:
1855     // Turn logical shift right of a constant into an unsigned divide.
1856     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
1857       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
1858       Constant *X = ConstantInt::get(
1859         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
1860       return SE.getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
1861     }
1862     break;
1863
1864   case Instruction::Trunc:
1865     return SE.getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
1866
1867   case Instruction::ZExt:
1868     return SE.getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
1869
1870   case Instruction::SExt:
1871     return SE.getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
1872
1873   case Instruction::BitCast:
1874     // BitCasts are no-op casts so we just eliminate the cast.
1875     if (U->getType()->isInteger() &&
1876         U->getOperand(0)->getType()->isInteger())
1877       return getSCEV(U->getOperand(0));
1878     break;
1879
1880   case Instruction::PHI:
1881     return createNodeForPHI(cast<PHINode>(U));
1882
1883   case Instruction::Select:
1884     // This could be a smax or umax that was lowered earlier.
1885     // Try to recover it.
1886     if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
1887       Value *LHS = ICI->getOperand(0);
1888       Value *RHS = ICI->getOperand(1);
1889       switch (ICI->getPredicate()) {
1890       case ICmpInst::ICMP_SLT:
1891       case ICmpInst::ICMP_SLE:
1892         std::swap(LHS, RHS);
1893         // fall through
1894       case ICmpInst::ICMP_SGT:
1895       case ICmpInst::ICMP_SGE:
1896         if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
1897           return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
1898         else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
1899           // ~smax(~x, ~y) == smin(x, y).
1900           return SE.getNotSCEV(SE.getSMaxExpr(
1901                                    SE.getNotSCEV(getSCEV(LHS)),
1902                                    SE.getNotSCEV(getSCEV(RHS))));
1903         break;
1904       case ICmpInst::ICMP_ULT:
1905       case ICmpInst::ICMP_ULE:
1906         std::swap(LHS, RHS);
1907         // fall through
1908       case ICmpInst::ICMP_UGT:
1909       case ICmpInst::ICMP_UGE:
1910         if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
1911           return SE.getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
1912         else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
1913           // ~umax(~x, ~y) == umin(x, y)
1914           return SE.getNotSCEV(SE.getUMaxExpr(SE.getNotSCEV(getSCEV(LHS)),
1915                                               SE.getNotSCEV(getSCEV(RHS))));
1916         break;
1917       default:
1918         break;
1919       }
1920     }
1921
1922   default: // We cannot analyze this expression.
1923     break;
1924   }
1925
1926   return SE.getUnknown(V);
1927 }
1928
1929
1930
1931 //===----------------------------------------------------------------------===//
1932 //                   Iteration Count Computation Code
1933 //
1934
1935 /// getIterationCount - If the specified loop has a predictable iteration
1936 /// count, return it.  Note that it is not valid to call this method on a
1937 /// loop without a loop-invariant iteration count.
1938 SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) {
1939   std::map<const Loop*, SCEVHandle>::iterator I = IterationCounts.find(L);
1940   if (I == IterationCounts.end()) {
1941     SCEVHandle ItCount = ComputeIterationCount(L);
1942     I = IterationCounts.insert(std::make_pair(L, ItCount)).first;
1943     if (ItCount != UnknownValue) {
1944       assert(ItCount->isLoopInvariant(L) &&
1945              "Computed trip count isn't loop invariant for loop!");
1946       ++NumTripCountsComputed;
1947     } else if (isa<PHINode>(L->getHeader()->begin())) {
1948       // Only count loops that have phi nodes as not being computable.
1949       ++NumTripCountsNotComputed;
1950     }
1951   }
1952   return I->second;
1953 }
1954
1955 /// ComputeIterationCount - Compute the number of times the specified loop
1956 /// will iterate.
1957 SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
1958   // If the loop has a non-one exit block count, we can't analyze it.
1959   SmallVector<BasicBlock*, 8> ExitBlocks;
1960   L->getExitBlocks(ExitBlocks);
1961   if (ExitBlocks.size() != 1) return UnknownValue;
1962
1963   // Okay, there is one exit block.  Try to find the condition that causes the
1964   // loop to be exited.
1965   BasicBlock *ExitBlock = ExitBlocks[0];
1966
1967   BasicBlock *ExitingBlock = 0;
1968   for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock);
1969        PI != E; ++PI)
1970     if (L->contains(*PI)) {
1971       if (ExitingBlock == 0)
1972         ExitingBlock = *PI;
1973       else
1974         return UnknownValue;   // More than one block exiting!
1975     }
1976   assert(ExitingBlock && "No exits from loop, something is broken!");
1977
1978   // Okay, we've computed the exiting block.  See what condition causes us to
1979   // exit.
1980   //
1981   // FIXME: we should be able to handle switch instructions (with a single exit)
1982   BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
1983   if (ExitBr == 0) return UnknownValue;
1984   assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
1985   
1986   // At this point, we know we have a conditional branch that determines whether
1987   // the loop is exited.  However, we don't know if the branch is executed each
1988   // time through the loop.  If not, then the execution count of the branch will
1989   // not be equal to the trip count of the loop.
1990   //
1991   // Currently we check for this by checking to see if the Exit branch goes to
1992   // the loop header.  If so, we know it will always execute the same number of
1993   // times as the loop.  We also handle the case where the exit block *is* the
1994   // loop header.  This is common for un-rotated loops.  More extensive analysis
1995   // could be done to handle more cases here.
1996   if (ExitBr->getSuccessor(0) != L->getHeader() &&
1997       ExitBr->getSuccessor(1) != L->getHeader() &&
1998       ExitBr->getParent() != L->getHeader())
1999     return UnknownValue;
2000   
2001   ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition());
2002
2003   // If it's not an integer comparison then compute it the hard way. 
2004   // Note that ICmpInst deals with pointer comparisons too so we must check
2005   // the type of the operand.
2006   if (ExitCond == 0 || isa<PointerType>(ExitCond->getOperand(0)->getType()))
2007     return ComputeIterationCountExhaustively(L, ExitBr->getCondition(),
2008                                           ExitBr->getSuccessor(0) == ExitBlock);
2009
2010   // If the condition was exit on true, convert the condition to exit on false
2011   ICmpInst::Predicate Cond;
2012   if (ExitBr->getSuccessor(1) == ExitBlock)
2013     Cond = ExitCond->getPredicate();
2014   else
2015     Cond = ExitCond->getInversePredicate();
2016
2017   // Handle common loops like: for (X = "string"; *X; ++X)
2018   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
2019     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
2020       SCEVHandle ItCnt =
2021         ComputeLoadConstantCompareIterationCount(LI, RHS, L, Cond);
2022       if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt;
2023     }
2024
2025   SCEVHandle LHS = getSCEV(ExitCond->getOperand(0));
2026   SCEVHandle RHS = getSCEV(ExitCond->getOperand(1));
2027
2028   // Try to evaluate any dependencies out of the loop.
2029   SCEVHandle Tmp = getSCEVAtScope(LHS, L);
2030   if (!isa<SCEVCouldNotCompute>(Tmp)) LHS = Tmp;
2031   Tmp = getSCEVAtScope(RHS, L);
2032   if (!isa<SCEVCouldNotCompute>(Tmp)) RHS = Tmp;
2033
2034   // At this point, we would like to compute how many iterations of the 
2035   // loop the predicate will return true for these inputs.
2036   if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
2037     // If there is a loop-invariant, force it into the RHS.
2038     std::swap(LHS, RHS);
2039     Cond = ICmpInst::getSwappedPredicate(Cond);
2040   }
2041
2042   // FIXME: think about handling pointer comparisons!  i.e.:
2043   // while (P != P+100) ++P;
2044
2045   // If we have a comparison of a chrec against a constant, try to use value
2046   // ranges to answer this query.
2047   if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
2048     if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
2049       if (AddRec->getLoop() == L) {
2050         // Form the comparison range using the constant of the correct type so
2051         // that the ConstantRange class knows to do a signed or unsigned
2052         // comparison.
2053         ConstantInt *CompVal = RHSC->getValue();
2054         const Type *RealTy = ExitCond->getOperand(0)->getType();
2055         CompVal = dyn_cast<ConstantInt>(
2056           ConstantExpr::getBitCast(CompVal, RealTy));
2057         if (CompVal) {
2058           // Form the constant range.
2059           ConstantRange CompRange(
2060               ICmpInst::makeConstantRange(Cond, CompVal->getValue()));
2061
2062           SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, SE);
2063           if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
2064         }
2065       }
2066
2067   switch (Cond) {
2068   case ICmpInst::ICMP_NE: {                     // while (X != Y)
2069     // Convert to: while (X-Y != 0)
2070     SCEVHandle TC = HowFarToZero(SE.getMinusSCEV(LHS, RHS), L);
2071     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2072     break;
2073   }
2074   case ICmpInst::ICMP_EQ: {
2075     // Convert to: while (X-Y == 0)           // while (X == Y)
2076     SCEVHandle TC = HowFarToNonZero(SE.getMinusSCEV(LHS, RHS), L);
2077     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2078     break;
2079   }
2080   case ICmpInst::ICMP_SLT: {
2081     SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true, false);
2082     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2083     break;
2084   }
2085   case ICmpInst::ICMP_SGT: {
2086     SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
2087                                      SE.getNotSCEV(RHS), L, true, false);
2088     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2089     break;
2090   }
2091   case ICmpInst::ICMP_ULT: {
2092     SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false, false);
2093     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2094     break;
2095   }
2096   case ICmpInst::ICMP_UGT: {
2097     SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
2098                                      SE.getNotSCEV(RHS), L, false, false);
2099     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2100     break;
2101   }
2102   case ICmpInst::ICMP_SLE: {
2103     SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true, true);
2104     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2105     break;
2106   }
2107   case ICmpInst::ICMP_SGE: {
2108     SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
2109                                      SE.getNotSCEV(RHS), L, true, true);
2110     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2111     break;
2112   }
2113   case ICmpInst::ICMP_ULE: {
2114     SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false, true);
2115     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2116     break;
2117   }
2118   case ICmpInst::ICMP_UGE: {
2119     SCEVHandle TC = HowManyLessThans(SE.getNotSCEV(LHS),
2120                                      SE.getNotSCEV(RHS), L, false, true);
2121     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2122     break;
2123   }
2124   default:
2125 #if 0
2126     cerr << "ComputeIterationCount ";
2127     if (ExitCond->getOperand(0)->getType()->isUnsigned())
2128       cerr << "[unsigned] ";
2129     cerr << *LHS << "   "
2130          << Instruction::getOpcodeName(Instruction::ICmp) 
2131          << "   " << *RHS << "\n";
2132 #endif
2133     break;
2134   }
2135   return ComputeIterationCountExhaustively(L, ExitCond,
2136                                        ExitBr->getSuccessor(0) == ExitBlock);
2137 }
2138
2139 static ConstantInt *
2140 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
2141                                 ScalarEvolution &SE) {
2142   SCEVHandle InVal = SE.getConstant(C);
2143   SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE);
2144   assert(isa<SCEVConstant>(Val) &&
2145          "Evaluation of SCEV at constant didn't fold correctly?");
2146   return cast<SCEVConstant>(Val)->getValue();
2147 }
2148
2149 /// GetAddressedElementFromGlobal - Given a global variable with an initializer
2150 /// and a GEP expression (missing the pointer index) indexing into it, return
2151 /// the addressed element of the initializer or null if the index expression is
2152 /// invalid.
2153 static Constant *
2154 GetAddressedElementFromGlobal(GlobalVariable *GV,
2155                               const std::vector<ConstantInt*> &Indices) {
2156   Constant *Init = GV->getInitializer();
2157   for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
2158     uint64_t Idx = Indices[i]->getZExtValue();
2159     if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
2160       assert(Idx < CS->getNumOperands() && "Bad struct index!");
2161       Init = cast<Constant>(CS->getOperand(Idx));
2162     } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
2163       if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
2164       Init = cast<Constant>(CA->getOperand(Idx));
2165     } else if (isa<ConstantAggregateZero>(Init)) {
2166       if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
2167         assert(Idx < STy->getNumElements() && "Bad struct index!");
2168         Init = Constant::getNullValue(STy->getElementType(Idx));
2169       } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
2170         if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
2171         Init = Constant::getNullValue(ATy->getElementType());
2172       } else {
2173         assert(0 && "Unknown constant aggregate type!");
2174       }
2175       return 0;
2176     } else {
2177       return 0; // Unknown initializer type
2178     }
2179   }
2180   return Init;
2181 }
2182
2183 /// ComputeLoadConstantCompareIterationCount - Given an exit condition of
2184 /// 'icmp op load X, cst', try to see if we can compute the trip count.
2185 SCEVHandle ScalarEvolutionsImpl::
2186 ComputeLoadConstantCompareIterationCount(LoadInst *LI, Constant *RHS,
2187                                          const Loop *L, 
2188                                          ICmpInst::Predicate predicate) {
2189   if (LI->isVolatile()) return UnknownValue;
2190
2191   // Check to see if the loaded pointer is a getelementptr of a global.
2192   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
2193   if (!GEP) return UnknownValue;
2194
2195   // Make sure that it is really a constant global we are gepping, with an
2196   // initializer, and make sure the first IDX is really 0.
2197   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
2198   if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
2199       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
2200       !cast<Constant>(GEP->getOperand(1))->isNullValue())
2201     return UnknownValue;
2202
2203   // Okay, we allow one non-constant index into the GEP instruction.
2204   Value *VarIdx = 0;
2205   std::vector<ConstantInt*> Indexes;
2206   unsigned VarIdxNum = 0;
2207   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
2208     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
2209       Indexes.push_back(CI);
2210     } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
2211       if (VarIdx) return UnknownValue;  // Multiple non-constant idx's.
2212       VarIdx = GEP->getOperand(i);
2213       VarIdxNum = i-2;
2214       Indexes.push_back(0);
2215     }
2216
2217   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
2218   // Check to see if X is a loop variant variable value now.
2219   SCEVHandle Idx = getSCEV(VarIdx);
2220   SCEVHandle Tmp = getSCEVAtScope(Idx, L);
2221   if (!isa<SCEVCouldNotCompute>(Tmp)) Idx = Tmp;
2222
2223   // We can only recognize very limited forms of loop index expressions, in
2224   // particular, only affine AddRec's like {C1,+,C2}.
2225   SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
2226   if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
2227       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
2228       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
2229     return UnknownValue;
2230
2231   unsigned MaxSteps = MaxBruteForceIterations;
2232   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
2233     ConstantInt *ItCst =
2234       ConstantInt::get(IdxExpr->getType(), IterationNum);
2235     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, SE);
2236
2237     // Form the GEP offset.
2238     Indexes[VarIdxNum] = Val;
2239
2240     Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
2241     if (Result == 0) break;  // Cannot compute!
2242
2243     // Evaluate the condition for this iteration.
2244     Result = ConstantExpr::getICmp(predicate, Result, RHS);
2245     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
2246     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
2247 #if 0
2248       cerr << "\n***\n*** Computed loop count " << *ItCst
2249            << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
2250            << "***\n";
2251 #endif
2252       ++NumArrayLenItCounts;
2253       return SE.getConstant(ItCst);   // Found terminating iteration!
2254     }
2255   }
2256   return UnknownValue;
2257 }
2258
2259
2260 /// CanConstantFold - Return true if we can constant fold an instruction of the
2261 /// specified type, assuming that all operands were constants.
2262 static bool CanConstantFold(const Instruction *I) {
2263   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
2264       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
2265     return true;
2266
2267   if (const CallInst *CI = dyn_cast<CallInst>(I))
2268     if (const Function *F = CI->getCalledFunction())
2269       return canConstantFoldCallTo(F);
2270   return false;
2271 }
2272
2273 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
2274 /// in the loop that V is derived from.  We allow arbitrary operations along the
2275 /// way, but the operands of an operation must either be constants or a value
2276 /// derived from a constant PHI.  If this expression does not fit with these
2277 /// constraints, return null.
2278 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
2279   // If this is not an instruction, or if this is an instruction outside of the
2280   // loop, it can't be derived from a loop PHI.
2281   Instruction *I = dyn_cast<Instruction>(V);
2282   if (I == 0 || !L->contains(I->getParent())) return 0;
2283
2284   if (PHINode *PN = dyn_cast<PHINode>(I)) {
2285     if (L->getHeader() == I->getParent())
2286       return PN;
2287     else
2288       // We don't currently keep track of the control flow needed to evaluate
2289       // PHIs, so we cannot handle PHIs inside of loops.
2290       return 0;
2291   }
2292
2293   // If we won't be able to constant fold this expression even if the operands
2294   // are constants, return early.
2295   if (!CanConstantFold(I)) return 0;
2296
2297   // Otherwise, we can evaluate this instruction if all of its operands are
2298   // constant or derived from a PHI node themselves.
2299   PHINode *PHI = 0;
2300   for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
2301     if (!(isa<Constant>(I->getOperand(Op)) ||
2302           isa<GlobalValue>(I->getOperand(Op)))) {
2303       PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
2304       if (P == 0) return 0;  // Not evolving from PHI
2305       if (PHI == 0)
2306         PHI = P;
2307       else if (PHI != P)
2308         return 0;  // Evolving from multiple different PHIs.
2309     }
2310
2311   // This is a expression evolving from a constant PHI!
2312   return PHI;
2313 }
2314
2315 /// EvaluateExpression - Given an expression that passes the
2316 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
2317 /// in the loop has the value PHIVal.  If we can't fold this expression for some
2318 /// reason, return null.
2319 static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
2320   if (isa<PHINode>(V)) return PHIVal;
2321   if (Constant *C = dyn_cast<Constant>(V)) return C;
2322   Instruction *I = cast<Instruction>(V);
2323
2324   std::vector<Constant*> Operands;
2325   Operands.resize(I->getNumOperands());
2326
2327   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2328     Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
2329     if (Operands[i] == 0) return 0;
2330   }
2331
2332   if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2333     return ConstantFoldCompareInstOperands(CI->getPredicate(),
2334                                            &Operands[0], Operands.size());
2335   else
2336     return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2337                                     &Operands[0], Operands.size());
2338 }
2339
2340 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
2341 /// in the header of its containing loop, we know the loop executes a
2342 /// constant number of times, and the PHI node is just a recurrence
2343 /// involving constants, fold it.
2344 Constant *ScalarEvolutionsImpl::
2345 getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& Its, const Loop *L){
2346   std::map<PHINode*, Constant*>::iterator I =
2347     ConstantEvolutionLoopExitValue.find(PN);
2348   if (I != ConstantEvolutionLoopExitValue.end())
2349     return I->second;
2350
2351   if (Its.ugt(APInt(Its.getBitWidth(),MaxBruteForceIterations)))
2352     return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
2353
2354   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
2355
2356   // Since the loop is canonicalized, the PHI node must have two entries.  One
2357   // entry must be a constant (coming in from outside of the loop), and the
2358   // second must be derived from the same PHI.
2359   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2360   Constant *StartCST =
2361     dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2362   if (StartCST == 0)
2363     return RetVal = 0;  // Must be a constant.
2364
2365   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2366   PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2367   if (PN2 != PN)
2368     return RetVal = 0;  // Not derived from same PHI.
2369
2370   // Execute the loop symbolically to determine the exit value.
2371   if (Its.getActiveBits() >= 32)
2372     return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
2373
2374   unsigned NumIterations = Its.getZExtValue(); // must be in range
2375   unsigned IterationNum = 0;
2376   for (Constant *PHIVal = StartCST; ; ++IterationNum) {
2377     if (IterationNum == NumIterations)
2378       return RetVal = PHIVal;  // Got exit value!
2379
2380     // Compute the value of the PHI node for the next iteration.
2381     Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2382     if (NextPHI == PHIVal)
2383       return RetVal = NextPHI;  // Stopped evolving!
2384     if (NextPHI == 0)
2385       return 0;        // Couldn't evaluate!
2386     PHIVal = NextPHI;
2387   }
2388 }
2389
2390 /// ComputeIterationCountExhaustively - If the trip is known to execute a
2391 /// constant number of times (the condition evolves only from constants),
2392 /// try to evaluate a few iterations of the loop until we get the exit
2393 /// condition gets a value of ExitWhen (true or false).  If we cannot
2394 /// evaluate the trip count of the loop, return UnknownValue.
2395 SCEVHandle ScalarEvolutionsImpl::
2396 ComputeIterationCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) {
2397   PHINode *PN = getConstantEvolvingPHI(Cond, L);
2398   if (PN == 0) return UnknownValue;
2399
2400   // Since the loop is canonicalized, the PHI node must have two entries.  One
2401   // entry must be a constant (coming in from outside of the loop), and the
2402   // second must be derived from the same PHI.
2403   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2404   Constant *StartCST =
2405     dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2406   if (StartCST == 0) return UnknownValue;  // Must be a constant.
2407
2408   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2409   PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2410   if (PN2 != PN) return UnknownValue;  // Not derived from same PHI.
2411
2412   // Okay, we find a PHI node that defines the trip count of this loop.  Execute
2413   // the loop symbolically to determine when the condition gets a value of
2414   // "ExitWhen".
2415   unsigned IterationNum = 0;
2416   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
2417   for (Constant *PHIVal = StartCST;
2418        IterationNum != MaxIterations; ++IterationNum) {
2419     ConstantInt *CondVal =
2420       dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
2421
2422     // Couldn't symbolically evaluate.
2423     if (!CondVal) return UnknownValue;
2424
2425     if (CondVal->getValue() == uint64_t(ExitWhen)) {
2426       ConstantEvolutionLoopExitValue[PN] = PHIVal;
2427       ++NumBruteForceTripCountsComputed;
2428       return SE.getConstant(ConstantInt::get(Type::Int32Ty, IterationNum));
2429     }
2430
2431     // Compute the value of the PHI node for the next iteration.
2432     Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2433     if (NextPHI == 0 || NextPHI == PHIVal)
2434       return UnknownValue;  // Couldn't evaluate or not making progress...
2435     PHIVal = NextPHI;
2436   }
2437
2438   // Too many iterations were needed to evaluate.
2439   return UnknownValue;
2440 }
2441
2442 /// getSCEVAtScope - Compute the value of the specified expression within the
2443 /// indicated loop (which may be null to indicate in no loop).  If the
2444 /// expression cannot be evaluated, return UnknownValue.
2445 SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
2446   // FIXME: this should be turned into a virtual method on SCEV!
2447
2448   if (isa<SCEVConstant>(V)) return V;
2449
2450   // If this instruction is evolved from a constant-evolving PHI, compute the
2451   // exit value from the loop without using SCEVs.
2452   if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
2453     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
2454       const Loop *LI = this->LI[I->getParent()];
2455       if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
2456         if (PHINode *PN = dyn_cast<PHINode>(I))
2457           if (PN->getParent() == LI->getHeader()) {
2458             // Okay, there is no closed form solution for the PHI node.  Check
2459             // to see if the loop that contains it has a known iteration count.
2460             // If so, we may be able to force computation of the exit value.
2461             SCEVHandle IterationCount = getIterationCount(LI);
2462             if (SCEVConstant *ICC = dyn_cast<SCEVConstant>(IterationCount)) {
2463               // Okay, we know how many times the containing loop executes.  If
2464               // this is a constant evolving PHI node, get the final value at
2465               // the specified iteration number.
2466               Constant *RV = getConstantEvolutionLoopExitValue(PN,
2467                                                     ICC->getValue()->getValue(),
2468                                                                LI);
2469               if (RV) return SE.getUnknown(RV);
2470             }
2471           }
2472
2473       // Okay, this is an expression that we cannot symbolically evaluate
2474       // into a SCEV.  Check to see if it's possible to symbolically evaluate
2475       // the arguments into constants, and if so, try to constant propagate the
2476       // result.  This is particularly useful for computing loop exit values.
2477       if (CanConstantFold(I)) {
2478         std::vector<Constant*> Operands;
2479         Operands.reserve(I->getNumOperands());
2480         for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2481           Value *Op = I->getOperand(i);
2482           if (Constant *C = dyn_cast<Constant>(Op)) {
2483             Operands.push_back(C);
2484           } else {
2485             // If any of the operands is non-constant and if they are
2486             // non-integer, don't even try to analyze them with scev techniques.
2487             if (!isa<IntegerType>(Op->getType()))
2488               return V;
2489               
2490             SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L);
2491             if (SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV))
2492               Operands.push_back(ConstantExpr::getIntegerCast(SC->getValue(), 
2493                                                               Op->getType(), 
2494                                                               false));
2495             else if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
2496               if (Constant *C = dyn_cast<Constant>(SU->getValue()))
2497                 Operands.push_back(ConstantExpr::getIntegerCast(C, 
2498                                                                 Op->getType(), 
2499                                                                 false));
2500               else
2501                 return V;
2502             } else {
2503               return V;
2504             }
2505           }
2506         }
2507         
2508         Constant *C;
2509         if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2510           C = ConstantFoldCompareInstOperands(CI->getPredicate(),
2511                                               &Operands[0], Operands.size());
2512         else
2513           C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2514                                        &Operands[0], Operands.size());
2515         return SE.getUnknown(C);
2516       }
2517     }
2518
2519     // This is some other type of SCEVUnknown, just return it.
2520     return V;
2521   }
2522
2523   if (SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
2524     // Avoid performing the look-up in the common case where the specified
2525     // expression has no loop-variant portions.
2526     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
2527       SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
2528       if (OpAtScope != Comm->getOperand(i)) {
2529         if (OpAtScope == UnknownValue) return UnknownValue;
2530         // Okay, at least one of these operands is loop variant but might be
2531         // foldable.  Build a new instance of the folded commutative expression.
2532         std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i);
2533         NewOps.push_back(OpAtScope);
2534
2535         for (++i; i != e; ++i) {
2536           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
2537           if (OpAtScope == UnknownValue) return UnknownValue;
2538           NewOps.push_back(OpAtScope);
2539         }
2540         if (isa<SCEVAddExpr>(Comm))
2541           return SE.getAddExpr(NewOps);
2542         if (isa<SCEVMulExpr>(Comm))
2543           return SE.getMulExpr(NewOps);
2544         if (isa<SCEVSMaxExpr>(Comm))
2545           return SE.getSMaxExpr(NewOps);
2546         if (isa<SCEVUMaxExpr>(Comm))
2547           return SE.getUMaxExpr(NewOps);
2548         assert(0 && "Unknown commutative SCEV type!");
2549       }
2550     }
2551     // If we got here, all operands are loop invariant.
2552     return Comm;
2553   }
2554
2555   if (SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(V)) {
2556     SCEVHandle LHS = getSCEVAtScope(UDiv->getLHS(), L);
2557     if (LHS == UnknownValue) return LHS;
2558     SCEVHandle RHS = getSCEVAtScope(UDiv->getRHS(), L);
2559     if (RHS == UnknownValue) return RHS;
2560     if (LHS == UDiv->getLHS() && RHS == UDiv->getRHS())
2561       return UDiv;   // must be loop invariant
2562     return SE.getUDivExpr(LHS, RHS);
2563   }
2564
2565   if (SCEVSDivExpr *SDiv = dyn_cast<SCEVSDivExpr>(V)) {
2566     SCEVHandle LHS = getSCEVAtScope(SDiv->getLHS(), L);
2567     if (LHS == UnknownValue) return LHS;
2568     SCEVHandle RHS = getSCEVAtScope(SDiv->getRHS(), L);
2569     if (RHS == UnknownValue) return RHS;
2570     if (LHS == SDiv->getLHS() && RHS == SDiv->getRHS())
2571       return SDiv;   // must be loop invariant
2572     return SE.getSDivExpr(LHS, RHS);
2573   }
2574
2575   // If this is a loop recurrence for a loop that does not contain L, then we
2576   // are dealing with the final value computed by the loop.
2577   if (SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
2578     if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
2579       // To evaluate this recurrence, we need to know how many times the AddRec
2580       // loop iterates.  Compute this now.
2581       SCEVHandle IterationCount = getIterationCount(AddRec->getLoop());
2582       if (IterationCount == UnknownValue) return UnknownValue;
2583
2584       // Then, evaluate the AddRec.
2585       return AddRec->evaluateAtIteration(IterationCount, SE);
2586     }
2587     return UnknownValue;
2588   }
2589
2590   //assert(0 && "Unknown SCEV type!");
2591   return UnknownValue;
2592 }
2593
2594 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
2595 /// following equation:
2596 ///
2597 ///     A * X = B (mod N)
2598 ///
2599 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
2600 /// A and B isn't important.
2601 ///
2602 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
2603 static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
2604                                                ScalarEvolution &SE) {
2605   uint32_t BW = A.getBitWidth();
2606   assert(BW == B.getBitWidth() && "Bit widths must be the same.");
2607   assert(A != 0 && "A must be non-zero.");
2608
2609   // 1. D = gcd(A, N)
2610   //
2611   // The gcd of A and N may have only one prime factor: 2. The number of
2612   // trailing zeros in A is its multiplicity
2613   uint32_t Mult2 = A.countTrailingZeros();
2614   // D = 2^Mult2
2615
2616   // 2. Check if B is divisible by D.
2617   //
2618   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
2619   // is not less than multiplicity of this prime factor for D.
2620   if (B.countTrailingZeros() < Mult2)
2621     return new SCEVCouldNotCompute();
2622
2623   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
2624   // modulo (N / D).
2625   //
2626   // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
2627   // bit width during computations.
2628   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
2629   APInt Mod(BW + 1, 0);
2630   Mod.set(BW - Mult2);  // Mod = N / D
2631   APInt I = AD.multiplicativeInverse(Mod);
2632
2633   // 4. Compute the minimum unsigned root of the equation:
2634   // I * (B / D) mod (N / D)
2635   APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
2636
2637   // The result is guaranteed to be less than 2^BW so we may truncate it to BW
2638   // bits.
2639   return SE.getConstant(Result.trunc(BW));
2640 }
2641
2642 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the
2643 /// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
2644 /// might be the same) or two SCEVCouldNotCompute objects.
2645 ///
2646 static std::pair<SCEVHandle,SCEVHandle>
2647 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
2648   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
2649   SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
2650   SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
2651   SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
2652
2653   // We currently can only solve this if the coefficients are constants.
2654   if (!LC || !MC || !NC) {
2655     SCEV *CNC = new SCEVCouldNotCompute();
2656     return std::make_pair(CNC, CNC);
2657   }
2658
2659   uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
2660   const APInt &L = LC->getValue()->getValue();
2661   const APInt &M = MC->getValue()->getValue();
2662   const APInt &N = NC->getValue()->getValue();
2663   APInt Two(BitWidth, 2);
2664   APInt Four(BitWidth, 4);
2665
2666   { 
2667     using namespace APIntOps;
2668     const APInt& C = L;
2669     // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
2670     // The B coefficient is M-N/2
2671     APInt B(M);
2672     B -= sdiv(N,Two);
2673
2674     // The A coefficient is N/2
2675     APInt A(N.sdiv(Two));
2676
2677     // Compute the B^2-4ac term.
2678     APInt SqrtTerm(B);
2679     SqrtTerm *= B;
2680     SqrtTerm -= Four * (A * C);
2681
2682     // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
2683     // integer value or else APInt::sqrt() will assert.
2684     APInt SqrtVal(SqrtTerm.sqrt());
2685
2686     // Compute the two solutions for the quadratic formula. 
2687     // The divisions must be performed as signed divisions.
2688     APInt NegB(-B);
2689     APInt TwoA( A << 1 );
2690     if (TwoA.isMinValue()) {
2691       SCEV *CNC = new SCEVCouldNotCompute();
2692       return std::make_pair(CNC, CNC);
2693     }
2694
2695     ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
2696     ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
2697
2698     return std::make_pair(SE.getConstant(Solution1), 
2699                           SE.getConstant(Solution2));
2700     } // end APIntOps namespace
2701 }
2702
2703 /// HowFarToZero - Return the number of times a backedge comparing the specified
2704 /// value to zero will execute.  If not computable, return UnknownValue
2705 SCEVHandle ScalarEvolutionsImpl::HowFarToZero(SCEV *V, const Loop *L) {
2706   // If the value is a constant
2707   if (SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
2708     // If the value is already zero, the branch will execute zero times.
2709     if (C->getValue()->isZero()) return C;
2710     return UnknownValue;  // Otherwise it will loop infinitely.
2711   }
2712
2713   SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
2714   if (!AddRec || AddRec->getLoop() != L)
2715     return UnknownValue;
2716
2717   if (AddRec->isAffine()) {
2718     // If this is an affine expression, the execution count of this branch is
2719     // the minimum unsigned root of the following equation:
2720     //
2721     //     Start + Step*N = 0 (mod 2^BW)
2722     //
2723     // equivalent to:
2724     //
2725     //             Step*N = -Start (mod 2^BW)
2726     //
2727     // where BW is the common bit width of Start and Step.
2728
2729     // Get the initial value for the loop.
2730     SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
2731     if (isa<SCEVCouldNotCompute>(Start)) return UnknownValue;
2732
2733     SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
2734
2735     if (SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
2736       // For now we handle only constant steps.
2737
2738       // First, handle unitary steps.
2739       if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
2740         return SE.getNegativeSCEV(Start);       //   N = -Start (as unsigned)
2741       if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
2742         return Start;                           //    N = Start (as unsigned)
2743
2744       // Then, try to solve the above equation provided that Start is constant.
2745       if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
2746         return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
2747                                             -StartC->getValue()->getValue(),SE);
2748     }
2749   } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
2750     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
2751     // the quadratic equation to solve it.
2752     std::pair<SCEVHandle,SCEVHandle> Roots = SolveQuadraticEquation(AddRec, SE);
2753     SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
2754     SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
2755     if (R1) {
2756 #if 0
2757       cerr << "HFTZ: " << *V << " - sol#1: " << *R1
2758            << "  sol#2: " << *R2 << "\n";
2759 #endif
2760       // Pick the smallest positive root value.
2761       if (ConstantInt *CB =
2762           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, 
2763                                    R1->getValue(), R2->getValue()))) {
2764         if (CB->getZExtValue() == false)
2765           std::swap(R1, R2);   // R1 is the minimum root now.
2766
2767         // We can only use this value if the chrec ends up with an exact zero
2768         // value at this index.  When solving for "X*X != 5", for example, we
2769         // should not accept a root of 2.
2770         SCEVHandle Val = AddRec->evaluateAtIteration(R1, SE);
2771         if (Val->isZero())
2772           return R1;  // We found a quadratic root!
2773       }
2774     }
2775   }
2776
2777   return UnknownValue;
2778 }
2779
2780 /// HowFarToNonZero - Return the number of times a backedge checking the
2781 /// specified value for nonzero will execute.  If not computable, return
2782 /// UnknownValue
2783 SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) {
2784   // Loops that look like: while (X == 0) are very strange indeed.  We don't
2785   // handle them yet except for the trivial case.  This could be expanded in the
2786   // future as needed.
2787
2788   // If the value is a constant, check to see if it is known to be non-zero
2789   // already.  If so, the backedge will execute zero times.
2790   if (SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
2791     if (!C->getValue()->isNullValue())
2792       return SE.getIntegerSCEV(0, C->getType());
2793     return UnknownValue;  // Otherwise it will loop infinitely.
2794   }
2795
2796   // We could implement others, but I really doubt anyone writes loops like
2797   // this, and if they did, they would already be constant folded.
2798   return UnknownValue;
2799 }
2800
2801 /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
2802 /// (which may not be an immediate predecessor) which has exactly one
2803 /// successor from which BB is reachable, or null if no such block is
2804 /// found.
2805 ///
2806 BasicBlock *
2807 ScalarEvolutionsImpl::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
2808   // If the block has a unique predecessor, the predecessor must have
2809   // no other successors from which BB is reachable.
2810   if (BasicBlock *Pred = BB->getSinglePredecessor())
2811     return Pred;
2812
2813   // A loop's header is defined to be a block that dominates the loop.
2814   // If the loop has a preheader, it must be a block that has exactly
2815   // one successor that can reach BB. This is slightly more strict
2816   // than necessary, but works if critical edges are split.
2817   if (Loop *L = LI.getLoopFor(BB))
2818     return L->getLoopPreheader();
2819
2820   return 0;
2821 }
2822
2823 /// executesAtLeastOnce - Test whether entry to the loop is protected by
2824 /// a conditional between LHS and RHS.
2825 bool ScalarEvolutionsImpl::executesAtLeastOnce(const Loop *L, bool isSigned,
2826                                                bool trueWhenEqual,
2827                                                SCEV *LHS, SCEV *RHS) {
2828   BasicBlock *Preheader = L->getLoopPreheader();
2829   BasicBlock *PreheaderDest = L->getHeader();
2830
2831   // Starting at the preheader, climb up the predecessor chain, as long as
2832   // there are predecessors that can be found that have unique successors
2833   // leading to the original header.
2834   for (; Preheader;
2835        PreheaderDest = Preheader,
2836        Preheader = getPredecessorWithUniqueSuccessorForBB(Preheader)) {
2837
2838     BranchInst *LoopEntryPredicate =
2839       dyn_cast<BranchInst>(Preheader->getTerminator());
2840     if (!LoopEntryPredicate ||
2841         LoopEntryPredicate->isUnconditional())
2842       continue;
2843
2844     ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
2845     if (!ICI) continue;
2846
2847     // Now that we found a conditional branch that dominates the loop, check to
2848     // see if it is the comparison we are looking for.
2849     Value *PreCondLHS = ICI->getOperand(0);
2850     Value *PreCondRHS = ICI->getOperand(1);
2851     ICmpInst::Predicate Cond;
2852     if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
2853       Cond = ICI->getPredicate();
2854     else
2855       Cond = ICI->getInversePredicate();
2856
2857     switch (Cond) {
2858     case ICmpInst::ICMP_UGT:
2859       if (isSigned || trueWhenEqual) continue;
2860       std::swap(PreCondLHS, PreCondRHS);
2861       Cond = ICmpInst::ICMP_ULT;
2862       break;
2863     case ICmpInst::ICMP_SGT:
2864       if (!isSigned || trueWhenEqual) continue;
2865       std::swap(PreCondLHS, PreCondRHS);
2866       Cond = ICmpInst::ICMP_SLT;
2867       break;
2868     case ICmpInst::ICMP_ULT:
2869       if (isSigned || trueWhenEqual) continue;
2870       break;
2871     case ICmpInst::ICMP_SLT:
2872       if (!isSigned || trueWhenEqual) continue;
2873       break;
2874     case ICmpInst::ICMP_UGE:
2875       if (isSigned || !trueWhenEqual) continue;
2876       std::swap(PreCondLHS, PreCondRHS);
2877       Cond = ICmpInst::ICMP_ULE;
2878       break;
2879     case ICmpInst::ICMP_SGE:
2880       if (!isSigned || !trueWhenEqual) continue;
2881       std::swap(PreCondLHS, PreCondRHS);
2882       Cond = ICmpInst::ICMP_SLE;
2883       break;
2884     case ICmpInst::ICMP_ULE:
2885       if (isSigned || !trueWhenEqual) continue;
2886       break;
2887     case ICmpInst::ICMP_SLE:
2888       if (!isSigned || !trueWhenEqual) continue;
2889       break;
2890     default:
2891       continue;
2892     }
2893
2894     if (!PreCondLHS->getType()->isInteger()) continue;
2895
2896     SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS);
2897     SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS);
2898     if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) ||
2899         (LHS == SE.getNotSCEV(PreCondRHSSCEV) &&
2900          RHS == SE.getNotSCEV(PreCondLHSSCEV)))
2901       return true;
2902   }
2903
2904   return false;
2905 }
2906
2907 /// potentialInfiniteLoop - Test whether the loop might jump over the exit value
2908 /// due to wrapping around 2^n.
2909 bool ScalarEvolutionsImpl::potentialInfiniteLoop(SCEV *Stride, SCEV *RHS,
2910                                                  bool isSigned, bool trueWhenEqual) {
2911   // Return true when the distance from RHS to maxint > Stride.
2912
2913   if (!isa<SCEVConstant>(Stride))
2914     return true;
2915   SCEVConstant *SC = cast<SCEVConstant>(Stride);
2916
2917   if (SC->getValue()->isZero())
2918     return true;
2919   if (!trueWhenEqual && SC->getValue()->isOne())
2920     return false;
2921
2922   if (!isa<SCEVConstant>(RHS))
2923     return true;
2924   SCEVConstant *R = cast<SCEVConstant>(RHS);
2925
2926   if (isSigned)
2927     return true;  // XXX: because we don't have an sdiv scev.
2928
2929   // If negative, it wraps around every iteration, but we don't care about that.
2930   APInt S = SC->getValue()->getValue().abs();
2931
2932   APInt Dist = APInt::getMaxValue(R->getValue()->getBitWidth()) -
2933                R->getValue()->getValue();
2934
2935   if (trueWhenEqual)
2936     return !S.ult(Dist);
2937   else
2938     return !S.ule(Dist);
2939 }
2940
2941 /// HowManyLessThans - Return the number of times a backedge containing the
2942 /// specified less-than comparison will execute.  If not computable, return
2943 /// UnknownValue.
2944 SCEVHandle ScalarEvolutionsImpl::
2945 HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
2946                  bool isSigned, bool trueWhenEqual) {
2947   // Only handle:  "ADDREC < LoopInvariant".
2948   if (!RHS->isLoopInvariant(L)) return UnknownValue;
2949
2950   SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
2951   if (!AddRec || AddRec->getLoop() != L)
2952     return UnknownValue;
2953
2954   if (AddRec->isAffine()) {
2955     SCEVHandle Stride = AddRec->getOperand(1);
2956     if (potentialInfiniteLoop(Stride, RHS, isSigned, trueWhenEqual))
2957       return UnknownValue;
2958
2959     // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
2960     // m.  So, we count the number of iterations in which {n,+,s} < m is true.
2961     // Note that we cannot simply return max(m-n,0)/s because it's not safe to
2962     // treat m-n as signed nor unsigned due to overflow possibility.
2963
2964     // First, we get the value of the LHS in the first iteration: n
2965     SCEVHandle Start = AddRec->getOperand(0);
2966
2967     SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType());
2968
2969     // Assuming that the loop will run at least once, we know that it will
2970     // run (m-n)/s times.
2971     SCEVHandle End = RHS;
2972
2973     if (!executesAtLeastOnce(L, isSigned, trueWhenEqual,
2974                              SE.getMinusSCEV(Start, One), RHS)) {
2975       // If not, we get the value of the LHS in the first iteration in which
2976       // the above condition doesn't hold.  This equals to max(m,n).
2977       End = isSigned ? SE.getSMaxExpr(RHS, Start)
2978                      : SE.getUMaxExpr(RHS, Start);
2979     }
2980
2981     // If the expression is less-than-or-equal to, we need to extend the
2982     // loop by one iteration.
2983     //
2984     // The loop won't actually run (m-n)/s times because the loop iterations
2985     // won't divide evenly. For example, if you have {2,+,5} u< 10 the
2986     // division would equal one, but the loop runs twice putting the
2987     // induction variable at 12.
2988
2989     if (!trueWhenEqual)
2990       // (Stride - 1) is correct only because we know it's unsigned.
2991       // What we really want is to decrease the magnitude of Stride by one.
2992       Start = SE.getMinusSCEV(Start, SE.getMinusSCEV(Stride, One));
2993     else
2994       Start = SE.getMinusSCEV(Start, Stride);
2995
2996     // Finally, we subtract these two values to get the number of times the
2997     // backedge is executed: max(m,n)-n.
2998     return SE.getUDivExpr(SE.getMinusSCEV(End, Start), Stride);
2999   }
3000
3001   return UnknownValue;
3002 }
3003
3004 /// getNumIterationsInRange - Return the number of iterations of this loop that
3005 /// produce values in the specified constant range.  Another way of looking at
3006 /// this is that it returns the first iteration number where the value is not in
3007 /// the condition, thus computing the exit count. If the iteration count can't
3008 /// be computed, an instance of SCEVCouldNotCompute is returned.
3009 SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
3010                                                    ScalarEvolution &SE) const {
3011   if (Range.isFullSet())  // Infinite loop.
3012     return new SCEVCouldNotCompute();
3013
3014   // If the start is a non-zero constant, shift the range to simplify things.
3015   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
3016     if (!SC->getValue()->isZero()) {
3017       std::vector<SCEVHandle> Operands(op_begin(), op_end());
3018       Operands[0] = SE.getIntegerSCEV(0, SC->getType());
3019       SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop());
3020       if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
3021         return ShiftedAddRec->getNumIterationsInRange(
3022                            Range.subtract(SC->getValue()->getValue()), SE);
3023       // This is strange and shouldn't happen.
3024       return new SCEVCouldNotCompute();
3025     }
3026
3027   // The only time we can solve this is when we have all constant indices.
3028   // Otherwise, we cannot determine the overflow conditions.
3029   for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
3030     if (!isa<SCEVConstant>(getOperand(i)))
3031       return new SCEVCouldNotCompute();
3032
3033
3034   // Okay at this point we know that all elements of the chrec are constants and
3035   // that the start element is zero.
3036
3037   // First check to see if the range contains zero.  If not, the first
3038   // iteration exits.
3039   if (!Range.contains(APInt(getBitWidth(),0))) 
3040     return SE.getConstant(ConstantInt::get(getType(),0));
3041
3042   if (isAffine()) {
3043     // If this is an affine expression then we have this situation:
3044     //   Solve {0,+,A} in Range  ===  Ax in Range
3045
3046     // We know that zero is in the range.  If A is positive then we know that
3047     // the upper value of the range must be the first possible exit value.
3048     // If A is negative then the lower of the range is the last possible loop
3049     // value.  Also note that we already checked for a full range.
3050     APInt One(getBitWidth(),1);
3051     APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
3052     APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
3053
3054     // The exit value should be (End+A)/A.
3055     APInt ExitVal = (End + A).udiv(A);
3056     ConstantInt *ExitValue = ConstantInt::get(ExitVal);
3057
3058     // Evaluate at the exit value.  If we really did fall out of the valid
3059     // range, then we computed our trip count, otherwise wrap around or other
3060     // things must have happened.
3061     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
3062     if (Range.contains(Val->getValue()))
3063       return new SCEVCouldNotCompute();  // Something strange happened
3064
3065     // Ensure that the previous value is in the range.  This is a sanity check.
3066     assert(Range.contains(
3067            EvaluateConstantChrecAtConstant(this, 
3068            ConstantInt::get(ExitVal - One), SE)->getValue()) &&
3069            "Linear scev computation is off in a bad way!");
3070     return SE.getConstant(ExitValue);
3071   } else if (isQuadratic()) {
3072     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
3073     // quadratic equation to solve it.  To do this, we must frame our problem in
3074     // terms of figuring out when zero is crossed, instead of when
3075     // Range.getUpper() is crossed.
3076     std::vector<SCEVHandle> NewOps(op_begin(), op_end());
3077     NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
3078     SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop());
3079
3080     // Next, solve the constructed addrec
3081     std::pair<SCEVHandle,SCEVHandle> Roots =
3082       SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
3083     SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
3084     SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
3085     if (R1) {
3086       // Pick the smallest positive root value.
3087       if (ConstantInt *CB =
3088           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, 
3089                                    R1->getValue(), R2->getValue()))) {
3090         if (CB->getZExtValue() == false)
3091           std::swap(R1, R2);   // R1 is the minimum root now.
3092
3093         // Make sure the root is not off by one.  The returned iteration should
3094         // not be in the range, but the previous one should be.  When solving
3095         // for "X*X < 5", for example, we should not return a root of 2.
3096         ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
3097                                                              R1->getValue(),
3098                                                              SE);
3099         if (Range.contains(R1Val->getValue())) {
3100           // The next iteration must be out of the range...
3101           ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
3102
3103           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3104           if (!Range.contains(R1Val->getValue()))
3105             return SE.getConstant(NextVal);
3106           return new SCEVCouldNotCompute();  // Something strange happened
3107         }
3108
3109         // If R1 was not in the range, then it is a good return value.  Make
3110         // sure that R1-1 WAS in the range though, just in case.
3111         ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
3112         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3113         if (Range.contains(R1Val->getValue()))
3114           return R1;
3115         return new SCEVCouldNotCompute();  // Something strange happened
3116       }
3117     }
3118   }
3119
3120   return new SCEVCouldNotCompute();
3121 }
3122
3123
3124
3125 //===----------------------------------------------------------------------===//
3126 //                   ScalarEvolution Class Implementation
3127 //===----------------------------------------------------------------------===//
3128
3129 bool ScalarEvolution::runOnFunction(Function &F) {
3130   Impl = new ScalarEvolutionsImpl(*this, F, getAnalysis<LoopInfo>());
3131   return false;
3132 }
3133
3134 void ScalarEvolution::releaseMemory() {
3135   delete (ScalarEvolutionsImpl*)Impl;
3136   Impl = 0;
3137 }
3138
3139 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
3140   AU.setPreservesAll();
3141   AU.addRequiredTransitive<LoopInfo>();
3142 }
3143
3144 SCEVHandle ScalarEvolution::getSCEV(Value *V) const {
3145   return ((ScalarEvolutionsImpl*)Impl)->getSCEV(V);
3146 }
3147
3148 /// hasSCEV - Return true if the SCEV for this value has already been
3149 /// computed.
3150 bool ScalarEvolution::hasSCEV(Value *V) const {
3151   return ((ScalarEvolutionsImpl*)Impl)->hasSCEV(V);
3152 }
3153
3154
3155 /// setSCEV - Insert the specified SCEV into the map of current SCEVs for
3156 /// the specified value.
3157 void ScalarEvolution::setSCEV(Value *V, const SCEVHandle &H) {
3158   ((ScalarEvolutionsImpl*)Impl)->setSCEV(V, H);
3159 }
3160
3161
3162 SCEVHandle ScalarEvolution::getIterationCount(const Loop *L) const {
3163   return ((ScalarEvolutionsImpl*)Impl)->getIterationCount(L);
3164 }
3165
3166 bool ScalarEvolution::hasLoopInvariantIterationCount(const Loop *L) const {
3167   return !isa<SCEVCouldNotCompute>(getIterationCount(L));
3168 }
3169
3170 SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const {
3171   return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L);
3172 }
3173
3174 void ScalarEvolution::deleteValueFromRecords(Value *V) const {
3175   return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V);
3176 }
3177
3178 static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE,
3179                           const Loop *L) {
3180   // Print all inner loops first
3181   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
3182     PrintLoopInfo(OS, SE, *I);
3183
3184   OS << "Loop " << L->getHeader()->getName() << ": ";
3185
3186   SmallVector<BasicBlock*, 8> ExitBlocks;
3187   L->getExitBlocks(ExitBlocks);
3188   if (ExitBlocks.size() != 1)
3189     OS << "<multiple exits> ";
3190
3191   if (SE->hasLoopInvariantIterationCount(L)) {
3192     OS << *SE->getIterationCount(L) << " iterations! ";
3193   } else {
3194     OS << "Unpredictable iteration count. ";
3195   }
3196
3197   OS << "\n";
3198 }
3199
3200 void ScalarEvolution::print(std::ostream &OS, const Module* ) const {
3201   Function &F = ((ScalarEvolutionsImpl*)Impl)->F;
3202   LoopInfo &LI = ((ScalarEvolutionsImpl*)Impl)->LI;
3203
3204   OS << "Classifying expressions for: " << F.getName() << "\n";
3205   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
3206     if (I->getType()->isInteger()) {
3207       OS << *I;
3208       OS << "  -->  ";
3209       SCEVHandle SV = getSCEV(&*I);
3210       SV->print(OS);
3211       OS << "\t\t";
3212
3213       if (const Loop *L = LI.getLoopFor((*I).getParent())) {
3214         OS << "Exits: ";
3215         SCEVHandle ExitValue = getSCEVAtScope(&*I, L->getParentLoop());
3216         if (isa<SCEVCouldNotCompute>(ExitValue)) {
3217           OS << "<<Unknown>>";
3218         } else {
3219           OS << *ExitValue;
3220         }
3221       }
3222
3223
3224       OS << "\n";
3225     }
3226
3227   OS << "Determining loop execution counts for: " << F.getName() << "\n";
3228   for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I)
3229     PrintLoopInfo(OS, this, *I);
3230 }