b150db30c20adafc9418cae245f15b9c9214a3be
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the scalar evolution analysis
11 // engine, which is used primarily to analyze expressions involving induction
12 // variables in loops.
13 //
14 // There are several aspects to this library.  First is the representation of
15 // scalar expressions, which are represented as subclasses of the SCEV class.
16 // These classes are used to represent certain types of subexpressions that we
17 // can handle.  These classes are reference counted, managed by the SCEVHandle
18 // class.  We only create one SCEV of a particular shape, so pointer-comparisons
19 // for equality are legal.
20 //
21 // One important aspect of the SCEV objects is that they are never cyclic, even
22 // if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
23 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
24 // recurrence) then we represent it directly as a recurrence node, otherwise we
25 // represent it as a SCEVUnknown node.
26 //
27 // In addition to being able to represent expressions of various types, we also
28 // have folders that are used to build the *canonical* representation for a
29 // particular expression.  These folders are capable of using a variety of
30 // rewrite rules to simplify the expressions.
31 //
32 // Once the folders are defined, we can implement the more interesting
33 // higher-level code, such as the code that recognizes PHI nodes of various
34 // types, computes the execution count of a loop, etc.
35 //
36 // TODO: We should use these routines and value representations to implement
37 // dependence analysis!
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 // There are several good references for the techniques used in this analysis.
42 //
43 //  Chains of recurrences -- a method to expedite the evaluation
44 //  of closed-form functions
45 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
46 //
47 //  On computational properties of chains of recurrences
48 //  Eugene V. Zima
49 //
50 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
51 //  Robert A. van Engelen
52 //
53 //  Efficient Symbolic Analysis for Optimizing Compilers
54 //  Robert A. van Engelen
55 //
56 //  Using the chains of recurrences algebra for data dependence testing and
57 //  induction variable substitution
58 //  MS Thesis, Johnie Birch
59 //
60 //===----------------------------------------------------------------------===//
61
62 #define DEBUG_TYPE "scalar-evolution"
63 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
64 #include "llvm/Constants.h"
65 #include "llvm/DerivedTypes.h"
66 #include "llvm/GlobalVariable.h"
67 #include "llvm/Instructions.h"
68 #include "llvm/Analysis/ConstantFolding.h"
69 #include "llvm/Analysis/Dominators.h"
70 #include "llvm/Analysis/LoopInfo.h"
71 #include "llvm/Assembly/Writer.h"
72 #include "llvm/Target/TargetData.h"
73 #include "llvm/Support/CommandLine.h"
74 #include "llvm/Support/Compiler.h"
75 #include "llvm/Support/ConstantRange.h"
76 #include "llvm/Support/GetElementPtrTypeIterator.h"
77 #include "llvm/Support/InstIterator.h"
78 #include "llvm/Support/ManagedStatic.h"
79 #include "llvm/Support/MathExtras.h"
80 #include "llvm/Support/raw_ostream.h"
81 #include "llvm/ADT/Statistic.h"
82 #include "llvm/ADT/STLExtras.h"
83 #include <ostream>
84 #include <algorithm>
85 using namespace llvm;
86
87 STATISTIC(NumArrayLenItCounts,
88           "Number of trip counts computed with array length");
89 STATISTIC(NumTripCountsComputed,
90           "Number of loops with predictable loop counts");
91 STATISTIC(NumTripCountsNotComputed,
92           "Number of loops without predictable loop counts");
93 STATISTIC(NumBruteForceTripCountsComputed,
94           "Number of loops with trip counts computed by force");
95
96 static cl::opt<unsigned>
97 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
98                         cl::desc("Maximum number of iterations SCEV will "
99                                  "symbolically execute a constant derived loop"),
100                         cl::init(100));
101
102 static RegisterPass<ScalarEvolution>
103 R("scalar-evolution", "Scalar Evolution Analysis", false, true);
104 char ScalarEvolution::ID = 0;
105
106 //===----------------------------------------------------------------------===//
107 //                           SCEV class definitions
108 //===----------------------------------------------------------------------===//
109
110 //===----------------------------------------------------------------------===//
111 // Implementation of the SCEV class.
112 //
113 SCEV::~SCEV() {}
114 void SCEV::dump() const {
115   print(errs());
116   errs() << '\n';
117 }
118
119 void SCEV::print(std::ostream &o) const {
120   raw_os_ostream OS(o);
121   print(OS);
122 }
123
124 bool SCEV::isZero() const {
125   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
126     return SC->getValue()->isZero();
127   return false;
128 }
129
130 bool SCEV::isOne() const {
131   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
132     return SC->getValue()->isOne();
133   return false;
134 }
135
136 SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {}
137 SCEVCouldNotCompute::~SCEVCouldNotCompute() {}
138
139 bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
140   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
141   return false;
142 }
143
144 const Type *SCEVCouldNotCompute::getType() const {
145   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
146   return 0;
147 }
148
149 bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
150   assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
151   return false;
152 }
153
154 SCEVHandle SCEVCouldNotCompute::
155 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
156                                   const SCEVHandle &Conc,
157                                   ScalarEvolution &SE) const {
158   return this;
159 }
160
161 void SCEVCouldNotCompute::print(raw_ostream &OS) const {
162   OS << "***COULDNOTCOMPUTE***";
163 }
164
165 bool SCEVCouldNotCompute::classof(const SCEV *S) {
166   return S->getSCEVType() == scCouldNotCompute;
167 }
168
169
170 // SCEVConstants - Only allow the creation of one SCEVConstant for any
171 // particular value.  Don't use a SCEVHandle here, or else the object will
172 // never be deleted!
173 static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants;
174
175
176 SCEVConstant::~SCEVConstant() {
177   SCEVConstants->erase(V);
178 }
179
180 SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) {
181   SCEVConstant *&R = (*SCEVConstants)[V];
182   if (R == 0) R = new SCEVConstant(V);
183   return R;
184 }
185
186 SCEVHandle ScalarEvolution::getConstant(const APInt& Val) {
187   return getConstant(ConstantInt::get(Val));
188 }
189
190 const Type *SCEVConstant::getType() const { return V->getType(); }
191
192 void SCEVConstant::print(raw_ostream &OS) const {
193   WriteAsOperand(OS, V, false);
194 }
195
196 SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy,
197                            const SCEVHandle &op, const Type *ty)
198   : SCEV(SCEVTy), Op(op), Ty(ty) {}
199
200 SCEVCastExpr::~SCEVCastExpr() {}
201
202 bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
203   return Op->dominates(BB, DT);
204 }
205
206 // SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any
207 // particular input.  Don't use a SCEVHandle here, or else the object will
208 // never be deleted!
209 static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, 
210                      SCEVTruncateExpr*> > SCEVTruncates;
211
212 SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty)
213   : SCEVCastExpr(scTruncate, op, ty) {
214   assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
215          (Ty->isInteger() || isa<PointerType>(Ty)) &&
216          "Cannot truncate non-integer value!");
217 }
218
219 SCEVTruncateExpr::~SCEVTruncateExpr() {
220   SCEVTruncates->erase(std::make_pair(Op, Ty));
221 }
222
223 void SCEVTruncateExpr::print(raw_ostream &OS) const {
224   OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
225 }
226
227 // SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any
228 // particular input.  Don't use a SCEVHandle here, or else the object will never
229 // be deleted!
230 static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>,
231                      SCEVZeroExtendExpr*> > SCEVZeroExtends;
232
233 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty)
234   : SCEVCastExpr(scZeroExtend, op, ty) {
235   assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
236          (Ty->isInteger() || isa<PointerType>(Ty)) &&
237          "Cannot zero extend non-integer value!");
238 }
239
240 SCEVZeroExtendExpr::~SCEVZeroExtendExpr() {
241   SCEVZeroExtends->erase(std::make_pair(Op, Ty));
242 }
243
244 void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
245   OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
246 }
247
248 // SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
249 // particular input.  Don't use a SCEVHandle here, or else the object will never
250 // be deleted!
251 static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>,
252                      SCEVSignExtendExpr*> > SCEVSignExtends;
253
254 SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
255   : SCEVCastExpr(scSignExtend, op, ty) {
256   assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
257          (Ty->isInteger() || isa<PointerType>(Ty)) &&
258          "Cannot sign extend non-integer value!");
259 }
260
261 SCEVSignExtendExpr::~SCEVSignExtendExpr() {
262   SCEVSignExtends->erase(std::make_pair(Op, Ty));
263 }
264
265 void SCEVSignExtendExpr::print(raw_ostream &OS) const {
266   OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
267 }
268
269 // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
270 // particular input.  Don't use a SCEVHandle here, or else the object will never
271 // be deleted!
272 static ManagedStatic<std::map<std::pair<unsigned, std::vector<const SCEV*> >,
273                      SCEVCommutativeExpr*> > SCEVCommExprs;
274
275 SCEVCommutativeExpr::~SCEVCommutativeExpr() {
276   std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
277   SCEVCommExprs->erase(std::make_pair(getSCEVType(), SCEVOps));
278 }
279
280 void SCEVCommutativeExpr::print(raw_ostream &OS) const {
281   assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
282   const char *OpStr = getOperationStr();
283   OS << "(" << *Operands[0];
284   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
285     OS << OpStr << *Operands[i];
286   OS << ")";
287 }
288
289 SCEVHandle SCEVCommutativeExpr::
290 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
291                                   const SCEVHandle &Conc,
292                                   ScalarEvolution &SE) const {
293   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
294     SCEVHandle H =
295       getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
296     if (H != getOperand(i)) {
297       std::vector<SCEVHandle> NewOps;
298       NewOps.reserve(getNumOperands());
299       for (unsigned j = 0; j != i; ++j)
300         NewOps.push_back(getOperand(j));
301       NewOps.push_back(H);
302       for (++i; i != e; ++i)
303         NewOps.push_back(getOperand(i)->
304                          replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
305
306       if (isa<SCEVAddExpr>(this))
307         return SE.getAddExpr(NewOps);
308       else if (isa<SCEVMulExpr>(this))
309         return SE.getMulExpr(NewOps);
310       else if (isa<SCEVSMaxExpr>(this))
311         return SE.getSMaxExpr(NewOps);
312       else if (isa<SCEVUMaxExpr>(this))
313         return SE.getUMaxExpr(NewOps);
314       else
315         assert(0 && "Unknown commutative expr!");
316     }
317   }
318   return this;
319 }
320
321 bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
322   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
323     if (!getOperand(i)->dominates(BB, DT))
324       return false;
325   }
326   return true;
327 }
328
329
330 // SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
331 // input.  Don't use a SCEVHandle here, or else the object will never be
332 // deleted!
333 static ManagedStatic<std::map<std::pair<const SCEV*, const SCEV*>,
334                      SCEVUDivExpr*> > SCEVUDivs;
335
336 SCEVUDivExpr::~SCEVUDivExpr() {
337   SCEVUDivs->erase(std::make_pair(LHS, RHS));
338 }
339
340 bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
341   return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
342 }
343
344 void SCEVUDivExpr::print(raw_ostream &OS) const {
345   OS << "(" << *LHS << " /u " << *RHS << ")";
346 }
347
348 const Type *SCEVUDivExpr::getType() const {
349   return LHS->getType();
350 }
351
352 // SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any
353 // particular input.  Don't use a SCEVHandle here, or else the object will never
354 // be deleted!
355 static ManagedStatic<std::map<std::pair<const Loop *,
356                                         std::vector<const SCEV*> >,
357                      SCEVAddRecExpr*> > SCEVAddRecExprs;
358
359 SCEVAddRecExpr::~SCEVAddRecExpr() {
360   std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
361   SCEVAddRecExprs->erase(std::make_pair(L, SCEVOps));
362 }
363
364 SCEVHandle SCEVAddRecExpr::
365 replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
366                                   const SCEVHandle &Conc,
367                                   ScalarEvolution &SE) const {
368   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
369     SCEVHandle H =
370       getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
371     if (H != getOperand(i)) {
372       std::vector<SCEVHandle> NewOps;
373       NewOps.reserve(getNumOperands());
374       for (unsigned j = 0; j != i; ++j)
375         NewOps.push_back(getOperand(j));
376       NewOps.push_back(H);
377       for (++i; i != e; ++i)
378         NewOps.push_back(getOperand(i)->
379                          replaceSymbolicValuesWithConcrete(Sym, Conc, SE));
380
381       return SE.getAddRecExpr(NewOps, L);
382     }
383   }
384   return this;
385 }
386
387
388 bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
389   // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't
390   // contain L and if the start is invariant.
391   // Add recurrences are never invariant in the function-body (null loop).
392   return QueryLoop &&
393          !QueryLoop->contains(L->getHeader()) &&
394          getOperand(0)->isLoopInvariant(QueryLoop);
395 }
396
397
398 void SCEVAddRecExpr::print(raw_ostream &OS) const {
399   OS << "{" << *Operands[0];
400   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
401     OS << ",+," << *Operands[i];
402   OS << "}<" << L->getHeader()->getName() + ">";
403 }
404
405 // SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular
406 // value.  Don't use a SCEVHandle here, or else the object will never be
407 // deleted!
408 static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns;
409
410 SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); }
411
412 bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
413   // All non-instruction values are loop invariant.  All instructions are loop
414   // invariant if they are not contained in the specified loop.
415   // Instructions are never considered invariant in the function body
416   // (null loop) because they are defined within the "loop".
417   if (Instruction *I = dyn_cast<Instruction>(V))
418     return L && !L->contains(I->getParent());
419   return true;
420 }
421
422 bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
423   if (Instruction *I = dyn_cast<Instruction>(getValue()))
424     return DT->dominates(I->getParent(), BB);
425   return true;
426 }
427
428 const Type *SCEVUnknown::getType() const {
429   return V->getType();
430 }
431
432 void SCEVUnknown::print(raw_ostream &OS) const {
433   WriteAsOperand(OS, V, false);
434 }
435
436 //===----------------------------------------------------------------------===//
437 //                               SCEV Utilities
438 //===----------------------------------------------------------------------===//
439
440 namespace {
441   /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
442   /// than the complexity of the RHS.  This comparator is used to canonicalize
443   /// expressions.
444   class VISIBILITY_HIDDEN SCEVComplexityCompare {
445     LoopInfo *LI;
446   public:
447     explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {}
448
449     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
450       // Primarily, sort the SCEVs by their getSCEVType().
451       if (LHS->getSCEVType() != RHS->getSCEVType())
452         return LHS->getSCEVType() < RHS->getSCEVType();
453
454       // Aside from the getSCEVType() ordering, the particular ordering
455       // isn't very important except that it's beneficial to be consistent,
456       // so that (a + b) and (b + a) don't end up as different expressions.
457
458       // Sort SCEVUnknown values with some loose heuristics. TODO: This is
459       // not as complete as it could be.
460       if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
461         const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
462
463         // Order pointer values after integer values. This helps SCEVExpander
464         // form GEPs.
465         if (isa<PointerType>(LU->getType()) && !isa<PointerType>(RU->getType()))
466           return false;
467         if (isa<PointerType>(RU->getType()) && !isa<PointerType>(LU->getType()))
468           return true;
469
470         // Compare getValueID values.
471         if (LU->getValue()->getValueID() != RU->getValue()->getValueID())
472           return LU->getValue()->getValueID() < RU->getValue()->getValueID();
473
474         // Sort arguments by their position.
475         if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) {
476           const Argument *RA = cast<Argument>(RU->getValue());
477           return LA->getArgNo() < RA->getArgNo();
478         }
479
480         // For instructions, compare their loop depth, and their opcode.
481         // This is pretty loose.
482         if (Instruction *LV = dyn_cast<Instruction>(LU->getValue())) {
483           Instruction *RV = cast<Instruction>(RU->getValue());
484
485           // Compare loop depths.
486           if (LI->getLoopDepth(LV->getParent()) !=
487               LI->getLoopDepth(RV->getParent()))
488             return LI->getLoopDepth(LV->getParent()) <
489                    LI->getLoopDepth(RV->getParent());
490
491           // Compare opcodes.
492           if (LV->getOpcode() != RV->getOpcode())
493             return LV->getOpcode() < RV->getOpcode();
494
495           // Compare the number of operands.
496           if (LV->getNumOperands() != RV->getNumOperands())
497             return LV->getNumOperands() < RV->getNumOperands();
498         }
499
500         return false;
501       }
502
503       // Constant sorting doesn't matter since they'll be folded.
504       if (isa<SCEVConstant>(LHS))
505         return false;
506
507       // Lexicographically compare n-ary expressions.
508       if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
509         const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
510         for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) {
511           if (i >= RC->getNumOperands())
512             return false;
513           if (operator()(LC->getOperand(i), RC->getOperand(i)))
514             return true;
515           if (operator()(RC->getOperand(i), LC->getOperand(i)))
516             return false;
517         }
518         return LC->getNumOperands() < RC->getNumOperands();
519       }
520
521       // Lexicographically compare udiv expressions.
522       if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
523         const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
524         if (operator()(LC->getLHS(), RC->getLHS()))
525           return true;
526         if (operator()(RC->getLHS(), LC->getLHS()))
527           return false;
528         if (operator()(LC->getRHS(), RC->getRHS()))
529           return true;
530         if (operator()(RC->getRHS(), LC->getRHS()))
531           return false;
532         return false;
533       }
534
535       // Compare cast expressions by operand.
536       if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
537         const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
538         return operator()(LC->getOperand(), RC->getOperand());
539       }
540
541       assert(0 && "Unknown SCEV kind!");
542       return false;
543     }
544   };
545 }
546
547 /// GroupByComplexity - Given a list of SCEV objects, order them by their
548 /// complexity, and group objects of the same complexity together by value.
549 /// When this routine is finished, we know that any duplicates in the vector are
550 /// consecutive and that complexity is monotonically increasing.
551 ///
552 /// Note that we go take special precautions to ensure that we get determinstic
553 /// results from this routine.  In other words, we don't want the results of
554 /// this to depend on where the addresses of various SCEV objects happened to
555 /// land in memory.
556 ///
557 static void GroupByComplexity(std::vector<SCEVHandle> &Ops,
558                               LoopInfo *LI) {
559   if (Ops.size() < 2) return;  // Noop
560   if (Ops.size() == 2) {
561     // This is the common case, which also happens to be trivially simple.
562     // Special case it.
563     if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
564       std::swap(Ops[0], Ops[1]);
565     return;
566   }
567
568   // Do the rough sort by complexity.
569   std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
570
571   // Now that we are sorted by complexity, group elements of the same
572   // complexity.  Note that this is, at worst, N^2, but the vector is likely to
573   // be extremely short in practice.  Note that we take this approach because we
574   // do not want to depend on the addresses of the objects we are grouping.
575   for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
576     const SCEV *S = Ops[i];
577     unsigned Complexity = S->getSCEVType();
578
579     // If there are any objects of the same complexity and same value as this
580     // one, group them.
581     for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
582       if (Ops[j] == S) { // Found a duplicate.
583         // Move it to immediately after i'th element.
584         std::swap(Ops[i+1], Ops[j]);
585         ++i;   // no need to rescan it.
586         if (i == e-2) return;  // Done!
587       }
588     }
589   }
590 }
591
592
593
594 //===----------------------------------------------------------------------===//
595 //                      Simple SCEV method implementations
596 //===----------------------------------------------------------------------===//
597
598 /// BinomialCoefficient - Compute BC(It, K).  The result has width W.
599 /// Assume, K > 0.
600 static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
601                                       ScalarEvolution &SE,
602                                       const Type* ResultTy) {
603   // Handle the simplest case efficiently.
604   if (K == 1)
605     return SE.getTruncateOrZeroExtend(It, ResultTy);
606
607   // We are using the following formula for BC(It, K):
608   //
609   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
610   //
611   // Suppose, W is the bitwidth of the return value.  We must be prepared for
612   // overflow.  Hence, we must assure that the result of our computation is
613   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
614   // safe in modular arithmetic.
615   //
616   // However, this code doesn't use exactly that formula; the formula it uses
617   // is something like the following, where T is the number of factors of 2 in 
618   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
619   // exponentiation:
620   //
621   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
622   //
623   // This formula is trivially equivalent to the previous formula.  However,
624   // this formula can be implemented much more efficiently.  The trick is that
625   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
626   // arithmetic.  To do exact division in modular arithmetic, all we have
627   // to do is multiply by the inverse.  Therefore, this step can be done at
628   // width W.
629   // 
630   // The next issue is how to safely do the division by 2^T.  The way this
631   // is done is by doing the multiplication step at a width of at least W + T
632   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
633   // when we perform the division by 2^T (which is equivalent to a right shift
634   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
635   // truncated out after the division by 2^T.
636   //
637   // In comparison to just directly using the first formula, this technique
638   // is much more efficient; using the first formula requires W * K bits,
639   // but this formula less than W + K bits. Also, the first formula requires
640   // a division step, whereas this formula only requires multiplies and shifts.
641   //
642   // It doesn't matter whether the subtraction step is done in the calculation
643   // width or the input iteration count's width; if the subtraction overflows,
644   // the result must be zero anyway.  We prefer here to do it in the width of
645   // the induction variable because it helps a lot for certain cases; CodeGen
646   // isn't smart enough to ignore the overflow, which leads to much less
647   // efficient code if the width of the subtraction is wider than the native
648   // register width.
649   //
650   // (It's possible to not widen at all by pulling out factors of 2 before
651   // the multiplication; for example, K=2 can be calculated as
652   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
653   // extra arithmetic, so it's not an obvious win, and it gets
654   // much more complicated for K > 3.)
655
656   // Protection from insane SCEVs; this bound is conservative,
657   // but it probably doesn't matter.
658   if (K > 1000)
659     return SE.getCouldNotCompute();
660
661   unsigned W = SE.getTypeSizeInBits(ResultTy);
662
663   // Calculate K! / 2^T and T; we divide out the factors of two before
664   // multiplying for calculating K! / 2^T to avoid overflow.
665   // Other overflow doesn't matter because we only care about the bottom
666   // W bits of the result.
667   APInt OddFactorial(W, 1);
668   unsigned T = 1;
669   for (unsigned i = 3; i <= K; ++i) {
670     APInt Mult(W, i);
671     unsigned TwoFactors = Mult.countTrailingZeros();
672     T += TwoFactors;
673     Mult = Mult.lshr(TwoFactors);
674     OddFactorial *= Mult;
675   }
676
677   // We need at least W + T bits for the multiplication step
678   unsigned CalculationBits = W + T;
679
680   // Calcuate 2^T, at width T+W.
681   APInt DivFactor = APInt(CalculationBits, 1).shl(T);
682
683   // Calculate the multiplicative inverse of K! / 2^T;
684   // this multiplication factor will perform the exact division by
685   // K! / 2^T.
686   APInt Mod = APInt::getSignedMinValue(W+1);
687   APInt MultiplyFactor = OddFactorial.zext(W+1);
688   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
689   MultiplyFactor = MultiplyFactor.trunc(W);
690
691   // Calculate the product, at width T+W
692   const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
693   SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
694   for (unsigned i = 1; i != K; ++i) {
695     SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
696     Dividend = SE.getMulExpr(Dividend,
697                              SE.getTruncateOrZeroExtend(S, CalculationTy));
698   }
699
700   // Divide by 2^T
701   SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
702
703   // Truncate the result, and divide by K! / 2^T.
704
705   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
706                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
707 }
708
709 /// evaluateAtIteration - Return the value of this chain of recurrences at
710 /// the specified iteration number.  We can evaluate this recurrence by
711 /// multiplying each element in the chain by the binomial coefficient
712 /// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
713 ///
714 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
715 ///
716 /// where BC(It, k) stands for binomial coefficient.
717 ///
718 SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
719                                                ScalarEvolution &SE) const {
720   SCEVHandle Result = getStart();
721   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
722     // The computation is correct in the face of overflow provided that the
723     // multiplication is performed _after_ the evaluation of the binomial
724     // coefficient.
725     SCEVHandle Coeff = BinomialCoefficient(It, i, SE, getType());
726     if (isa<SCEVCouldNotCompute>(Coeff))
727       return Coeff;
728
729     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
730   }
731   return Result;
732 }
733
734 //===----------------------------------------------------------------------===//
735 //                    SCEV Expression folder implementations
736 //===----------------------------------------------------------------------===//
737
738 SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op,
739                                             const Type *Ty) {
740   assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
741          "This is not a truncating conversion!");
742   assert(isSCEVable(Ty) &&
743          "This is not a conversion to a SCEVable type!");
744   Ty = getEffectiveSCEVType(Ty);
745
746   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
747     return getUnknown(
748         ConstantExpr::getTrunc(SC->getValue(), Ty));
749
750   // trunc(trunc(x)) --> trunc(x)
751   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
752     return getTruncateExpr(ST->getOperand(), Ty);
753
754   // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
755   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
756     return getTruncateOrSignExtend(SS->getOperand(), Ty);
757
758   // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
759   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
760     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
761
762   // If the input value is a chrec scev made out of constants, truncate
763   // all of the constants.
764   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
765     std::vector<SCEVHandle> Operands;
766     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
767       Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
768     return getAddRecExpr(Operands, AddRec->getLoop());
769   }
770
771   SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)];
772   if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty);
773   return Result;
774 }
775
776 SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op,
777                                               const Type *Ty) {
778   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
779          "This is not an extending conversion!");
780   assert(isSCEVable(Ty) &&
781          "This is not a conversion to a SCEVable type!");
782   Ty = getEffectiveSCEVType(Ty);
783
784   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
785     const Type *IntTy = getEffectiveSCEVType(Ty);
786     Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy);
787     if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
788     return getUnknown(C);
789   }
790
791   // zext(zext(x)) --> zext(x)
792   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
793     return getZeroExtendExpr(SZ->getOperand(), Ty);
794
795   // If the input value is a chrec scev, and we can prove that the value
796   // did not overflow the old, smaller, value, we can zero extend all of the
797   // operands (often constants).  This allows analysis of something like
798   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
799   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
800     if (AR->isAffine()) {
801       // Check whether the backedge-taken count is SCEVCouldNotCompute.
802       // Note that this serves two purposes: It filters out loops that are
803       // simply not analyzable, and it covers the case where this code is
804       // being called from within backedge-taken count analysis, such that
805       // attempting to ask for the backedge-taken count would likely result
806       // in infinite recursion. In the later case, the analysis code will
807       // cope with a conservative value, and it will take care to purge
808       // that value once it has finished.
809       SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
810       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
811         // Manually compute the final value for AR, checking for
812         // overflow.
813         SCEVHandle Start = AR->getStart();
814         SCEVHandle Step = AR->getStepRecurrence(*this);
815
816         // Check whether the backedge-taken count can be losslessly casted to
817         // the addrec's type. The count is always unsigned.
818         SCEVHandle CastedMaxBECount =
819           getTruncateOrZeroExtend(MaxBECount, Start->getType());
820         SCEVHandle RecastedMaxBECount =
821           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
822         if (MaxBECount == RecastedMaxBECount) {
823           const Type *WideTy =
824             IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
825           // Check whether Start+Step*MaxBECount has no unsigned overflow.
826           SCEVHandle ZMul =
827             getMulExpr(CastedMaxBECount,
828                        getTruncateOrZeroExtend(Step, Start->getType()));
829           SCEVHandle Add = getAddExpr(Start, ZMul);
830           SCEVHandle OperandExtendedAdd =
831             getAddExpr(getZeroExtendExpr(Start, WideTy),
832                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
833                                   getZeroExtendExpr(Step, WideTy)));
834           if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
835             // Return the expression with the addrec on the outside.
836             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
837                                  getZeroExtendExpr(Step, Ty),
838                                  AR->getLoop());
839
840           // Similar to above, only this time treat the step value as signed.
841           // This covers loops that count down.
842           SCEVHandle SMul =
843             getMulExpr(CastedMaxBECount,
844                        getTruncateOrSignExtend(Step, Start->getType()));
845           Add = getAddExpr(Start, SMul);
846           OperandExtendedAdd =
847             getAddExpr(getZeroExtendExpr(Start, WideTy),
848                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
849                                   getSignExtendExpr(Step, WideTy)));
850           if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
851             // Return the expression with the addrec on the outside.
852             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
853                                  getSignExtendExpr(Step, Ty),
854                                  AR->getLoop());
855         }
856       }
857     }
858
859   SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)];
860   if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty);
861   return Result;
862 }
863
864 SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op,
865                                               const Type *Ty) {
866   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
867          "This is not an extending conversion!");
868   assert(isSCEVable(Ty) &&
869          "This is not a conversion to a SCEVable type!");
870   Ty = getEffectiveSCEVType(Ty);
871
872   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
873     const Type *IntTy = getEffectiveSCEVType(Ty);
874     Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy);
875     if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
876     return getUnknown(C);
877   }
878
879   // sext(sext(x)) --> sext(x)
880   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
881     return getSignExtendExpr(SS->getOperand(), Ty);
882
883   // If the input value is a chrec scev, and we can prove that the value
884   // did not overflow the old, smaller, value, we can sign extend all of the
885   // operands (often constants).  This allows analysis of something like
886   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
887   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
888     if (AR->isAffine()) {
889       // Check whether the backedge-taken count is SCEVCouldNotCompute.
890       // Note that this serves two purposes: It filters out loops that are
891       // simply not analyzable, and it covers the case where this code is
892       // being called from within backedge-taken count analysis, such that
893       // attempting to ask for the backedge-taken count would likely result
894       // in infinite recursion. In the later case, the analysis code will
895       // cope with a conservative value, and it will take care to purge
896       // that value once it has finished.
897       SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
898       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
899         // Manually compute the final value for AR, checking for
900         // overflow.
901         SCEVHandle Start = AR->getStart();
902         SCEVHandle Step = AR->getStepRecurrence(*this);
903
904         // Check whether the backedge-taken count can be losslessly casted to
905         // the addrec's type. The count is always unsigned.
906         SCEVHandle CastedMaxBECount =
907           getTruncateOrZeroExtend(MaxBECount, Start->getType());
908         SCEVHandle RecastedMaxBECount =
909           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
910         if (MaxBECount == RecastedMaxBECount) {
911           const Type *WideTy =
912             IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
913           // Check whether Start+Step*MaxBECount has no signed overflow.
914           SCEVHandle SMul =
915             getMulExpr(CastedMaxBECount,
916                        getTruncateOrSignExtend(Step, Start->getType()));
917           SCEVHandle Add = getAddExpr(Start, SMul);
918           SCEVHandle OperandExtendedAdd =
919             getAddExpr(getSignExtendExpr(Start, WideTy),
920                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
921                                   getSignExtendExpr(Step, WideTy)));
922           if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
923             // Return the expression with the addrec on the outside.
924             return getAddRecExpr(getSignExtendExpr(Start, Ty),
925                                  getSignExtendExpr(Step, Ty),
926                                  AR->getLoop());
927         }
928       }
929     }
930
931   SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
932   if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
933   return Result;
934 }
935
936 /// getAddExpr - Get a canonical add expression, or something simpler if
937 /// possible.
938 SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) {
939   assert(!Ops.empty() && "Cannot get empty add!");
940   if (Ops.size() == 1) return Ops[0];
941 #ifndef NDEBUG
942   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
943     assert(getEffectiveSCEVType(Ops[i]->getType()) ==
944            getEffectiveSCEVType(Ops[0]->getType()) &&
945            "SCEVAddExpr operand types don't match!");
946 #endif
947
948   // Sort by complexity, this groups all similar expression types together.
949   GroupByComplexity(Ops, LI);
950
951   // If there are any constants, fold them together.
952   unsigned Idx = 0;
953   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
954     ++Idx;
955     assert(Idx < Ops.size());
956     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
957       // We found two constants, fold them together!
958       ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + 
959                                            RHSC->getValue()->getValue());
960       Ops[0] = getConstant(Fold);
961       Ops.erase(Ops.begin()+1);  // Erase the folded element
962       if (Ops.size() == 1) return Ops[0];
963       LHSC = cast<SCEVConstant>(Ops[0]);
964     }
965
966     // If we are left with a constant zero being added, strip it off.
967     if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
968       Ops.erase(Ops.begin());
969       --Idx;
970     }
971   }
972
973   if (Ops.size() == 1) return Ops[0];
974
975   // Okay, check to see if the same value occurs in the operand list twice.  If
976   // so, merge them together into an multiply expression.  Since we sorted the
977   // list, these values are required to be adjacent.
978   const Type *Ty = Ops[0]->getType();
979   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
980     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
981       // Found a match, merge the two values into a multiply, and add any
982       // remaining values to the result.
983       SCEVHandle Two = getIntegerSCEV(2, Ty);
984       SCEVHandle Mul = getMulExpr(Ops[i], Two);
985       if (Ops.size() == 2)
986         return Mul;
987       Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
988       Ops.push_back(Mul);
989       return getAddExpr(Ops);
990     }
991
992   // Check for truncates. If all the operands are truncated from the same
993   // type, see if factoring out the truncate would permit the result to be
994   // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
995   // if the contents of the resulting outer trunc fold to something simple.
996   for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
997     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
998     const Type *DstType = Trunc->getType();
999     const Type *SrcType = Trunc->getOperand()->getType();
1000     std::vector<SCEVHandle> LargeOps;
1001     bool Ok = true;
1002     // Check all the operands to see if they can be represented in the
1003     // source type of the truncate.
1004     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1005       if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1006         if (T->getOperand()->getType() != SrcType) {
1007           Ok = false;
1008           break;
1009         }
1010         LargeOps.push_back(T->getOperand());
1011       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1012         // This could be either sign or zero extension, but sign extension
1013         // is much more likely to be foldable here.
1014         LargeOps.push_back(getSignExtendExpr(C, SrcType));
1015       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1016         std::vector<SCEVHandle> LargeMulOps;
1017         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1018           if (const SCEVTruncateExpr *T =
1019                 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1020             if (T->getOperand()->getType() != SrcType) {
1021               Ok = false;
1022               break;
1023             }
1024             LargeMulOps.push_back(T->getOperand());
1025           } else if (const SCEVConstant *C =
1026                        dyn_cast<SCEVConstant>(M->getOperand(j))) {
1027             // This could be either sign or zero extension, but sign extension
1028             // is much more likely to be foldable here.
1029             LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
1030           } else {
1031             Ok = false;
1032             break;
1033           }
1034         }
1035         if (Ok)
1036           LargeOps.push_back(getMulExpr(LargeMulOps));
1037       } else {
1038         Ok = false;
1039         break;
1040       }
1041     }
1042     if (Ok) {
1043       // Evaluate the expression in the larger type.
1044       SCEVHandle Fold = getAddExpr(LargeOps);
1045       // If it folds to something simple, use it. Otherwise, don't.
1046       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1047         return getTruncateExpr(Fold, DstType);
1048     }
1049   }
1050
1051   // Skip past any other cast SCEVs.
1052   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1053     ++Idx;
1054
1055   // If there are add operands they would be next.
1056   if (Idx < Ops.size()) {
1057     bool DeletedAdd = false;
1058     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1059       // If we have an add, expand the add operands onto the end of the operands
1060       // list.
1061       Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
1062       Ops.erase(Ops.begin()+Idx);
1063       DeletedAdd = true;
1064     }
1065
1066     // If we deleted at least one add, we added operands to the end of the list,
1067     // and they are not necessarily sorted.  Recurse to resort and resimplify
1068     // any operands we just aquired.
1069     if (DeletedAdd)
1070       return getAddExpr(Ops);
1071   }
1072
1073   // Skip over the add expression until we get to a multiply.
1074   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1075     ++Idx;
1076
1077   // If we are adding something to a multiply expression, make sure the
1078   // something is not already an operand of the multiply.  If so, merge it into
1079   // the multiply.
1080   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1081     const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1082     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1083       const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1084       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1085         if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) {
1086           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1087           SCEVHandle InnerMul = Mul->getOperand(MulOp == 0);
1088           if (Mul->getNumOperands() != 2) {
1089             // If the multiply has more than two operands, we must get the
1090             // Y*Z term.
1091             std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
1092             MulOps.erase(MulOps.begin()+MulOp);
1093             InnerMul = getMulExpr(MulOps);
1094           }
1095           SCEVHandle One = getIntegerSCEV(1, Ty);
1096           SCEVHandle AddOne = getAddExpr(InnerMul, One);
1097           SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]);
1098           if (Ops.size() == 2) return OuterMul;
1099           if (AddOp < Idx) {
1100             Ops.erase(Ops.begin()+AddOp);
1101             Ops.erase(Ops.begin()+Idx-1);
1102           } else {
1103             Ops.erase(Ops.begin()+Idx);
1104             Ops.erase(Ops.begin()+AddOp-1);
1105           }
1106           Ops.push_back(OuterMul);
1107           return getAddExpr(Ops);
1108         }
1109
1110       // Check this multiply against other multiplies being added together.
1111       for (unsigned OtherMulIdx = Idx+1;
1112            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1113            ++OtherMulIdx) {
1114         const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1115         // If MulOp occurs in OtherMul, we can fold the two multiplies
1116         // together.
1117         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1118              OMulOp != e; ++OMulOp)
1119           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1120             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1121             SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0);
1122             if (Mul->getNumOperands() != 2) {
1123               std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end());
1124               MulOps.erase(MulOps.begin()+MulOp);
1125               InnerMul1 = getMulExpr(MulOps);
1126             }
1127             SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1128             if (OtherMul->getNumOperands() != 2) {
1129               std::vector<SCEVHandle> MulOps(OtherMul->op_begin(),
1130                                              OtherMul->op_end());
1131               MulOps.erase(MulOps.begin()+OMulOp);
1132               InnerMul2 = getMulExpr(MulOps);
1133             }
1134             SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1135             SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1136             if (Ops.size() == 2) return OuterMul;
1137             Ops.erase(Ops.begin()+Idx);
1138             Ops.erase(Ops.begin()+OtherMulIdx-1);
1139             Ops.push_back(OuterMul);
1140             return getAddExpr(Ops);
1141           }
1142       }
1143     }
1144   }
1145
1146   // If there are any add recurrences in the operands list, see if any other
1147   // added values are loop invariant.  If so, we can fold them into the
1148   // recurrence.
1149   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1150     ++Idx;
1151
1152   // Scan over all recurrences, trying to fold loop invariants into them.
1153   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1154     // Scan all of the other operands to this add and add them to the vector if
1155     // they are loop invariant w.r.t. the recurrence.
1156     std::vector<SCEVHandle> LIOps;
1157     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1158     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1159       if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1160         LIOps.push_back(Ops[i]);
1161         Ops.erase(Ops.begin()+i);
1162         --i; --e;
1163       }
1164
1165     // If we found some loop invariants, fold them into the recurrence.
1166     if (!LIOps.empty()) {
1167       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1168       LIOps.push_back(AddRec->getStart());
1169
1170       std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end());
1171       AddRecOps[0] = getAddExpr(LIOps);
1172
1173       SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
1174       // If all of the other operands were loop invariant, we are done.
1175       if (Ops.size() == 1) return NewRec;
1176
1177       // Otherwise, add the folded AddRec by the non-liv parts.
1178       for (unsigned i = 0;; ++i)
1179         if (Ops[i] == AddRec) {
1180           Ops[i] = NewRec;
1181           break;
1182         }
1183       return getAddExpr(Ops);
1184     }
1185
1186     // Okay, if there weren't any loop invariants to be folded, check to see if
1187     // there are multiple AddRec's with the same loop induction variable being
1188     // added together.  If so, we can fold them.
1189     for (unsigned OtherIdx = Idx+1;
1190          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1191       if (OtherIdx != Idx) {
1192         const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1193         if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1194           // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
1195           std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end());
1196           for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
1197             if (i >= NewOps.size()) {
1198               NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
1199                             OtherAddRec->op_end());
1200               break;
1201             }
1202             NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
1203           }
1204           SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
1205
1206           if (Ops.size() == 2) return NewAddRec;
1207
1208           Ops.erase(Ops.begin()+Idx);
1209           Ops.erase(Ops.begin()+OtherIdx-1);
1210           Ops.push_back(NewAddRec);
1211           return getAddExpr(Ops);
1212         }
1213       }
1214
1215     // Otherwise couldn't fold anything into this recurrence.  Move onto the
1216     // next one.
1217   }
1218
1219   // Okay, it looks like we really DO need an add expr.  Check to see if we
1220   // already have one, otherwise create a new one.
1221   std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
1222   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr,
1223                                                                  SCEVOps)];
1224   if (Result == 0) Result = new SCEVAddExpr(Ops);
1225   return Result;
1226 }
1227
1228
1229 /// getMulExpr - Get a canonical multiply expression, or something simpler if
1230 /// possible.
1231 SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
1232   assert(!Ops.empty() && "Cannot get empty mul!");
1233 #ifndef NDEBUG
1234   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1235     assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1236            getEffectiveSCEVType(Ops[0]->getType()) &&
1237            "SCEVMulExpr operand types don't match!");
1238 #endif
1239
1240   // Sort by complexity, this groups all similar expression types together.
1241   GroupByComplexity(Ops, LI);
1242
1243   // If there are any constants, fold them together.
1244   unsigned Idx = 0;
1245   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1246
1247     // C1*(C2+V) -> C1*C2 + C1*V
1248     if (Ops.size() == 2)
1249       if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1250         if (Add->getNumOperands() == 2 &&
1251             isa<SCEVConstant>(Add->getOperand(0)))
1252           return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1253                             getMulExpr(LHSC, Add->getOperand(1)));
1254
1255
1256     ++Idx;
1257     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1258       // We found two constants, fold them together!
1259       ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * 
1260                                            RHSC->getValue()->getValue());
1261       Ops[0] = getConstant(Fold);
1262       Ops.erase(Ops.begin()+1);  // Erase the folded element
1263       if (Ops.size() == 1) return Ops[0];
1264       LHSC = cast<SCEVConstant>(Ops[0]);
1265     }
1266
1267     // If we are left with a constant one being multiplied, strip it off.
1268     if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1269       Ops.erase(Ops.begin());
1270       --Idx;
1271     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1272       // If we have a multiply of zero, it will always be zero.
1273       return Ops[0];
1274     }
1275   }
1276
1277   // Skip over the add expression until we get to a multiply.
1278   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1279     ++Idx;
1280
1281   if (Ops.size() == 1)
1282     return Ops[0];
1283
1284   // If there are mul operands inline them all into this expression.
1285   if (Idx < Ops.size()) {
1286     bool DeletedMul = false;
1287     while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1288       // If we have an mul, expand the mul operands onto the end of the operands
1289       // list.
1290       Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1291       Ops.erase(Ops.begin()+Idx);
1292       DeletedMul = true;
1293     }
1294
1295     // If we deleted at least one mul, we added operands to the end of the list,
1296     // and they are not necessarily sorted.  Recurse to resort and resimplify
1297     // any operands we just aquired.
1298     if (DeletedMul)
1299       return getMulExpr(Ops);
1300   }
1301
1302   // If there are any add recurrences in the operands list, see if any other
1303   // added values are loop invariant.  If so, we can fold them into the
1304   // recurrence.
1305   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1306     ++Idx;
1307
1308   // Scan over all recurrences, trying to fold loop invariants into them.
1309   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1310     // Scan all of the other operands to this mul and add them to the vector if
1311     // they are loop invariant w.r.t. the recurrence.
1312     std::vector<SCEVHandle> LIOps;
1313     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1314     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1315       if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1316         LIOps.push_back(Ops[i]);
1317         Ops.erase(Ops.begin()+i);
1318         --i; --e;
1319       }
1320
1321     // If we found some loop invariants, fold them into the recurrence.
1322     if (!LIOps.empty()) {
1323       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1324       std::vector<SCEVHandle> NewOps;
1325       NewOps.reserve(AddRec->getNumOperands());
1326       if (LIOps.size() == 1) {
1327         const SCEV *Scale = LIOps[0];
1328         for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1329           NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1330       } else {
1331         for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1332           std::vector<SCEVHandle> MulOps(LIOps);
1333           MulOps.push_back(AddRec->getOperand(i));
1334           NewOps.push_back(getMulExpr(MulOps));
1335         }
1336       }
1337
1338       SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1339
1340       // If all of the other operands were loop invariant, we are done.
1341       if (Ops.size() == 1) return NewRec;
1342
1343       // Otherwise, multiply the folded AddRec by the non-liv parts.
1344       for (unsigned i = 0;; ++i)
1345         if (Ops[i] == AddRec) {
1346           Ops[i] = NewRec;
1347           break;
1348         }
1349       return getMulExpr(Ops);
1350     }
1351
1352     // Okay, if there weren't any loop invariants to be folded, check to see if
1353     // there are multiple AddRec's with the same loop induction variable being
1354     // multiplied together.  If so, we can fold them.
1355     for (unsigned OtherIdx = Idx+1;
1356          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1357       if (OtherIdx != Idx) {
1358         const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1359         if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1360           // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
1361           const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1362           SCEVHandle NewStart = getMulExpr(F->getStart(),
1363                                                  G->getStart());
1364           SCEVHandle B = F->getStepRecurrence(*this);
1365           SCEVHandle D = G->getStepRecurrence(*this);
1366           SCEVHandle NewStep = getAddExpr(getMulExpr(F, D),
1367                                           getMulExpr(G, B),
1368                                           getMulExpr(B, D));
1369           SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep,
1370                                                F->getLoop());
1371           if (Ops.size() == 2) return NewAddRec;
1372
1373           Ops.erase(Ops.begin()+Idx);
1374           Ops.erase(Ops.begin()+OtherIdx-1);
1375           Ops.push_back(NewAddRec);
1376           return getMulExpr(Ops);
1377         }
1378       }
1379
1380     // Otherwise couldn't fold anything into this recurrence.  Move onto the
1381     // next one.
1382   }
1383
1384   // Okay, it looks like we really DO need an mul expr.  Check to see if we
1385   // already have one, otherwise create a new one.
1386   std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
1387   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr,
1388                                                                  SCEVOps)];
1389   if (Result == 0)
1390     Result = new SCEVMulExpr(Ops);
1391   return Result;
1392 }
1393
1394 /// getUDivExpr - Get a canonical multiply expression, or something simpler if
1395 /// possible.
1396 SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS,
1397                                         const SCEVHandle &RHS) {
1398   assert(getEffectiveSCEVType(LHS->getType()) ==
1399          getEffectiveSCEVType(RHS->getType()) &&
1400          "SCEVUDivExpr operand types don't match!");
1401
1402   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1403     if (RHSC->getValue()->equalsInt(1))
1404       return LHS;                            // X udiv 1 --> x
1405     if (RHSC->isZero())
1406       return getIntegerSCEV(0, LHS->getType()); // value is undefined
1407
1408     // Determine if the division can be folded into the operands of
1409     // its operands.
1410     // TODO: Generalize this to non-constants by using known-bits information.
1411     const Type *Ty = LHS->getType();
1412     unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
1413     unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
1414     // For non-power-of-two values, effectively round the value up to the
1415     // nearest power of two.
1416     if (!RHSC->getValue()->getValue().isPowerOf2())
1417       ++MaxShiftAmt;
1418     const IntegerType *ExtTy =
1419       IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt);
1420     // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
1421     if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
1422       if (const SCEVConstant *Step =
1423             dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
1424         if (!Step->getValue()->getValue()
1425               .urem(RHSC->getValue()->getValue()) &&
1426             getZeroExtendExpr(AR, ExtTy) ==
1427             getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
1428                           getZeroExtendExpr(Step, ExtTy),
1429                           AR->getLoop())) {
1430           std::vector<SCEVHandle> Operands;
1431           for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
1432             Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
1433           return getAddRecExpr(Operands, AR->getLoop());
1434         }
1435     // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
1436     if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
1437       std::vector<SCEVHandle> Operands;
1438       for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
1439         Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
1440       if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
1441         // Find an operand that's safely divisible.
1442         for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
1443           SCEVHandle Op = M->getOperand(i);
1444           SCEVHandle Div = getUDivExpr(Op, RHSC);
1445           if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
1446             Operands = M->getOperands();
1447             Operands[i] = Div;
1448             return getMulExpr(Operands);
1449           }
1450         }
1451     }
1452     // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
1453     if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) {
1454       std::vector<SCEVHandle> Operands;
1455       for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
1456         Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
1457       if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
1458         Operands.clear();
1459         for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
1460           SCEVHandle Op = getUDivExpr(A->getOperand(i), RHS);
1461           if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
1462             break;
1463           Operands.push_back(Op);
1464         }
1465         if (Operands.size() == A->getNumOperands())
1466           return getAddExpr(Operands);
1467       }
1468     }
1469
1470     // Fold if both operands are constant.
1471     if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1472       Constant *LHSCV = LHSC->getValue();
1473       Constant *RHSCV = RHSC->getValue();
1474       return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV));
1475     }
1476   }
1477
1478   SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)];
1479   if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS);
1480   return Result;
1481 }
1482
1483
1484 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
1485 /// Simplify the expression as much as possible.
1486 SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start,
1487                                const SCEVHandle &Step, const Loop *L) {
1488   std::vector<SCEVHandle> Operands;
1489   Operands.push_back(Start);
1490   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1491     if (StepChrec->getLoop() == L) {
1492       Operands.insert(Operands.end(), StepChrec->op_begin(),
1493                       StepChrec->op_end());
1494       return getAddRecExpr(Operands, L);
1495     }
1496
1497   Operands.push_back(Step);
1498   return getAddRecExpr(Operands, L);
1499 }
1500
1501 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
1502 /// Simplify the expression as much as possible.
1503 SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
1504                                           const Loop *L) {
1505   if (Operands.size() == 1) return Operands[0];
1506 #ifndef NDEBUG
1507   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
1508     assert(getEffectiveSCEVType(Operands[i]->getType()) ==
1509            getEffectiveSCEVType(Operands[0]->getType()) &&
1510            "SCEVAddRecExpr operand types don't match!");
1511 #endif
1512
1513   if (Operands.back()->isZero()) {
1514     Operands.pop_back();
1515     return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
1516   }
1517
1518   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
1519   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1520     const Loop* NestedLoop = NestedAR->getLoop();
1521     if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
1522       std::vector<SCEVHandle> NestedOperands(NestedAR->op_begin(),
1523                                              NestedAR->op_end());
1524       SCEVHandle NestedARHandle(NestedAR);
1525       Operands[0] = NestedAR->getStart();
1526       NestedOperands[0] = getAddRecExpr(Operands, L);
1527       return getAddRecExpr(NestedOperands, NestedLoop);
1528     }
1529   }
1530
1531   std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end());
1532   SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, SCEVOps)];
1533   if (Result == 0) Result = new SCEVAddRecExpr(Operands, L);
1534   return Result;
1535 }
1536
1537 SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
1538                                         const SCEVHandle &RHS) {
1539   std::vector<SCEVHandle> Ops;
1540   Ops.push_back(LHS);
1541   Ops.push_back(RHS);
1542   return getSMaxExpr(Ops);
1543 }
1544
1545 SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
1546   assert(!Ops.empty() && "Cannot get empty smax!");
1547   if (Ops.size() == 1) return Ops[0];
1548 #ifndef NDEBUG
1549   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1550     assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1551            getEffectiveSCEVType(Ops[0]->getType()) &&
1552            "SCEVSMaxExpr operand types don't match!");
1553 #endif
1554
1555   // Sort by complexity, this groups all similar expression types together.
1556   GroupByComplexity(Ops, LI);
1557
1558   // If there are any constants, fold them together.
1559   unsigned Idx = 0;
1560   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1561     ++Idx;
1562     assert(Idx < Ops.size());
1563     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1564       // We found two constants, fold them together!
1565       ConstantInt *Fold = ConstantInt::get(
1566                               APIntOps::smax(LHSC->getValue()->getValue(),
1567                                              RHSC->getValue()->getValue()));
1568       Ops[0] = getConstant(Fold);
1569       Ops.erase(Ops.begin()+1);  // Erase the folded element
1570       if (Ops.size() == 1) return Ops[0];
1571       LHSC = cast<SCEVConstant>(Ops[0]);
1572     }
1573
1574     // If we are left with a constant -inf, strip it off.
1575     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
1576       Ops.erase(Ops.begin());
1577       --Idx;
1578     }
1579   }
1580
1581   if (Ops.size() == 1) return Ops[0];
1582
1583   // Find the first SMax
1584   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
1585     ++Idx;
1586
1587   // Check to see if one of the operands is an SMax. If so, expand its operands
1588   // onto our operand list, and recurse to simplify.
1589   if (Idx < Ops.size()) {
1590     bool DeletedSMax = false;
1591     while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
1592       Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
1593       Ops.erase(Ops.begin()+Idx);
1594       DeletedSMax = true;
1595     }
1596
1597     if (DeletedSMax)
1598       return getSMaxExpr(Ops);
1599   }
1600
1601   // Okay, check to see if the same value occurs in the operand list twice.  If
1602   // so, delete one.  Since we sorted the list, these values are required to
1603   // be adjacent.
1604   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1605     if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
1606       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1607       --i; --e;
1608     }
1609
1610   if (Ops.size() == 1) return Ops[0];
1611
1612   assert(!Ops.empty() && "Reduced smax down to nothing!");
1613
1614   // Okay, it looks like we really DO need an smax expr.  Check to see if we
1615   // already have one, otherwise create a new one.
1616   std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
1617   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
1618                                                                  SCEVOps)];
1619   if (Result == 0) Result = new SCEVSMaxExpr(Ops);
1620   return Result;
1621 }
1622
1623 SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS,
1624                                         const SCEVHandle &RHS) {
1625   std::vector<SCEVHandle> Ops;
1626   Ops.push_back(LHS);
1627   Ops.push_back(RHS);
1628   return getUMaxExpr(Ops);
1629 }
1630
1631 SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) {
1632   assert(!Ops.empty() && "Cannot get empty umax!");
1633   if (Ops.size() == 1) return Ops[0];
1634 #ifndef NDEBUG
1635   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1636     assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1637            getEffectiveSCEVType(Ops[0]->getType()) &&
1638            "SCEVUMaxExpr operand types don't match!");
1639 #endif
1640
1641   // Sort by complexity, this groups all similar expression types together.
1642   GroupByComplexity(Ops, LI);
1643
1644   // If there are any constants, fold them together.
1645   unsigned Idx = 0;
1646   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1647     ++Idx;
1648     assert(Idx < Ops.size());
1649     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1650       // We found two constants, fold them together!
1651       ConstantInt *Fold = ConstantInt::get(
1652                               APIntOps::umax(LHSC->getValue()->getValue(),
1653                                              RHSC->getValue()->getValue()));
1654       Ops[0] = getConstant(Fold);
1655       Ops.erase(Ops.begin()+1);  // Erase the folded element
1656       if (Ops.size() == 1) return Ops[0];
1657       LHSC = cast<SCEVConstant>(Ops[0]);
1658     }
1659
1660     // If we are left with a constant zero, strip it off.
1661     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
1662       Ops.erase(Ops.begin());
1663       --Idx;
1664     }
1665   }
1666
1667   if (Ops.size() == 1) return Ops[0];
1668
1669   // Find the first UMax
1670   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
1671     ++Idx;
1672
1673   // Check to see if one of the operands is a UMax. If so, expand its operands
1674   // onto our operand list, and recurse to simplify.
1675   if (Idx < Ops.size()) {
1676     bool DeletedUMax = false;
1677     while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
1678       Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
1679       Ops.erase(Ops.begin()+Idx);
1680       DeletedUMax = true;
1681     }
1682
1683     if (DeletedUMax)
1684       return getUMaxExpr(Ops);
1685   }
1686
1687   // Okay, check to see if the same value occurs in the operand list twice.  If
1688   // so, delete one.  Since we sorted the list, these values are required to
1689   // be adjacent.
1690   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1691     if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
1692       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1693       --i; --e;
1694     }
1695
1696   if (Ops.size() == 1) return Ops[0];
1697
1698   assert(!Ops.empty() && "Reduced umax down to nothing!");
1699
1700   // Okay, it looks like we really DO need a umax expr.  Check to see if we
1701   // already have one, otherwise create a new one.
1702   std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
1703   SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr,
1704                                                                  SCEVOps)];
1705   if (Result == 0) Result = new SCEVUMaxExpr(Ops);
1706   return Result;
1707 }
1708
1709 SCEVHandle ScalarEvolution::getUnknown(Value *V) {
1710   if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
1711     return getConstant(CI);
1712   if (isa<ConstantPointerNull>(V))
1713     return getIntegerSCEV(0, V->getType());
1714   SCEVUnknown *&Result = (*SCEVUnknowns)[V];
1715   if (Result == 0) Result = new SCEVUnknown(V);
1716   return Result;
1717 }
1718
1719 //===----------------------------------------------------------------------===//
1720 //            Basic SCEV Analysis and PHI Idiom Recognition Code
1721 //
1722
1723 /// isSCEVable - Test if values of the given type are analyzable within
1724 /// the SCEV framework. This primarily includes integer types, and it
1725 /// can optionally include pointer types if the ScalarEvolution class
1726 /// has access to target-specific information.
1727 bool ScalarEvolution::isSCEVable(const Type *Ty) const {
1728   // Integers are always SCEVable.
1729   if (Ty->isInteger())
1730     return true;
1731
1732   // Pointers are SCEVable if TargetData information is available
1733   // to provide pointer size information.
1734   if (isa<PointerType>(Ty))
1735     return TD != NULL;
1736
1737   // Otherwise it's not SCEVable.
1738   return false;
1739 }
1740
1741 /// getTypeSizeInBits - Return the size in bits of the specified type,
1742 /// for which isSCEVable must return true.
1743 uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
1744   assert(isSCEVable(Ty) && "Type is not SCEVable!");
1745
1746   // If we have a TargetData, use it!
1747   if (TD)
1748     return TD->getTypeSizeInBits(Ty);
1749
1750   // Otherwise, we support only integer types.
1751   assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!");
1752   return Ty->getPrimitiveSizeInBits();
1753 }
1754
1755 /// getEffectiveSCEVType - Return a type with the same bitwidth as
1756 /// the given type and which represents how SCEV will treat the given
1757 /// type, for which isSCEVable must return true. For pointer types,
1758 /// this is the pointer-sized integer type.
1759 const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
1760   assert(isSCEVable(Ty) && "Type is not SCEVable!");
1761
1762   if (Ty->isInteger())
1763     return Ty;
1764
1765   assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!");
1766   return TD->getIntPtrType();
1767 }
1768
1769 SCEVHandle ScalarEvolution::getCouldNotCompute() {
1770   return UnknownValue;
1771 }
1772
1773 /// hasSCEV - Return true if the SCEV for this value has already been
1774 /// computed.
1775 bool ScalarEvolution::hasSCEV(Value *V) const {
1776   return Scalars.count(V);
1777 }
1778
1779 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
1780 /// expression and create a new one.
1781 SCEVHandle ScalarEvolution::getSCEV(Value *V) {
1782   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
1783
1784   std::map<SCEVCallbackVH, SCEVHandle>::iterator I = Scalars.find(V);
1785   if (I != Scalars.end()) return I->second;
1786   SCEVHandle S = createSCEV(V);
1787   Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
1788   return S;
1789 }
1790
1791 /// getIntegerSCEV - Given an integer or FP type, create a constant for the
1792 /// specified signed integer value and return a SCEV for the constant.
1793 SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
1794   Ty = getEffectiveSCEVType(Ty);
1795   Constant *C;
1796   if (Val == 0)
1797     C = Constant::getNullValue(Ty);
1798   else if (Ty->isFloatingPoint())
1799     C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle :
1800                                 APFloat::IEEEdouble, Val));
1801   else
1802     C = ConstantInt::get(Ty, Val);
1803   return getUnknown(C);
1804 }
1805
1806 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
1807 ///
1808 SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) {
1809   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
1810     return getUnknown(ConstantExpr::getNeg(VC->getValue()));
1811
1812   const Type *Ty = V->getType();
1813   Ty = getEffectiveSCEVType(Ty);
1814   return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(Ty)));
1815 }
1816
1817 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
1818 SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) {
1819   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
1820     return getUnknown(ConstantExpr::getNot(VC->getValue()));
1821
1822   const Type *Ty = V->getType();
1823   Ty = getEffectiveSCEVType(Ty);
1824   SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty));
1825   return getMinusSCEV(AllOnes, V);
1826 }
1827
1828 /// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
1829 ///
1830 SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
1831                                          const SCEVHandle &RHS) {
1832   // X - Y --> X + -Y
1833   return getAddExpr(LHS, getNegativeSCEV(RHS));
1834 }
1835
1836 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
1837 /// input value to the specified type.  If the type must be extended, it is zero
1838 /// extended.
1839 SCEVHandle
1840 ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V,
1841                                          const Type *Ty) {
1842   const Type *SrcTy = V->getType();
1843   assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1844          (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1845          "Cannot truncate or zero extend with non-integer arguments!");
1846   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1847     return V;  // No conversion
1848   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
1849     return getTruncateExpr(V, Ty);
1850   return getZeroExtendExpr(V, Ty);
1851 }
1852
1853 /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
1854 /// input value to the specified type.  If the type must be extended, it is sign
1855 /// extended.
1856 SCEVHandle
1857 ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V,
1858                                          const Type *Ty) {
1859   const Type *SrcTy = V->getType();
1860   assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1861          (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1862          "Cannot truncate or zero extend with non-integer arguments!");
1863   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1864     return V;  // No conversion
1865   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
1866     return getTruncateExpr(V, Ty);
1867   return getSignExtendExpr(V, Ty);
1868 }
1869
1870 /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
1871 /// input value to the specified type.  If the type must be extended, it is zero
1872 /// extended.  The conversion must not be narrowing.
1873 SCEVHandle
1874 ScalarEvolution::getNoopOrZeroExtend(const SCEVHandle &V, const Type *Ty) {
1875   const Type *SrcTy = V->getType();
1876   assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1877          (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1878          "Cannot noop or zero extend with non-integer arguments!");
1879   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
1880          "getNoopOrZeroExtend cannot truncate!");
1881   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1882     return V;  // No conversion
1883   return getZeroExtendExpr(V, Ty);
1884 }
1885
1886 /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
1887 /// input value to the specified type.  If the type must be extended, it is sign
1888 /// extended.  The conversion must not be narrowing.
1889 SCEVHandle
1890 ScalarEvolution::getNoopOrSignExtend(const SCEVHandle &V, const Type *Ty) {
1891   const Type *SrcTy = V->getType();
1892   assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1893          (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1894          "Cannot noop or sign extend with non-integer arguments!");
1895   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
1896          "getNoopOrSignExtend cannot truncate!");
1897   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1898     return V;  // No conversion
1899   return getSignExtendExpr(V, Ty);
1900 }
1901
1902 /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
1903 /// input value to the specified type.  The conversion must not be widening.
1904 SCEVHandle
1905 ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) {
1906   const Type *SrcTy = V->getType();
1907   assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
1908          (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
1909          "Cannot truncate or noop with non-integer arguments!");
1910   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
1911          "getTruncateOrNoop cannot extend!");
1912   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
1913     return V;  // No conversion
1914   return getTruncateExpr(V, Ty);
1915 }
1916
1917 /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
1918 /// the specified instruction and replaces any references to the symbolic value
1919 /// SymName with the specified value.  This is used during PHI resolution.
1920 void ScalarEvolution::
1921 ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName,
1922                                  const SCEVHandle &NewVal) {
1923   std::map<SCEVCallbackVH, SCEVHandle>::iterator SI =
1924     Scalars.find(SCEVCallbackVH(I, this));
1925   if (SI == Scalars.end()) return;
1926
1927   SCEVHandle NV =
1928     SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this);
1929   if (NV == SI->second) return;  // No change.
1930
1931   SI->second = NV;       // Update the scalars map!
1932
1933   // Any instruction values that use this instruction might also need to be
1934   // updated!
1935   for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
1936        UI != E; ++UI)
1937     ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
1938 }
1939
1940 /// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
1941 /// a loop header, making it a potential recurrence, or it doesn't.
1942 ///
1943 SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) {
1944   if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
1945     if (const Loop *L = LI->getLoopFor(PN->getParent()))
1946       if (L->getHeader() == PN->getParent()) {
1947         // If it lives in the loop header, it has two incoming values, one
1948         // from outside the loop, and one from inside.
1949         unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
1950         unsigned BackEdge     = IncomingEdge^1;
1951
1952         // While we are analyzing this PHI node, handle its value symbolically.
1953         SCEVHandle SymbolicName = getUnknown(PN);
1954         assert(Scalars.find(PN) == Scalars.end() &&
1955                "PHI node already processed?");
1956         Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
1957
1958         // Using this symbolic name for the PHI, analyze the value coming around
1959         // the back-edge.
1960         SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge));
1961
1962         // NOTE: If BEValue is loop invariant, we know that the PHI node just
1963         // has a special value for the first iteration of the loop.
1964
1965         // If the value coming around the backedge is an add with the symbolic
1966         // value we just inserted, then we found a simple induction variable!
1967         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
1968           // If there is a single occurrence of the symbolic value, replace it
1969           // with a recurrence.
1970           unsigned FoundIndex = Add->getNumOperands();
1971           for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
1972             if (Add->getOperand(i) == SymbolicName)
1973               if (FoundIndex == e) {
1974                 FoundIndex = i;
1975                 break;
1976               }
1977
1978           if (FoundIndex != Add->getNumOperands()) {
1979             // Create an add with everything but the specified operand.
1980             std::vector<SCEVHandle> Ops;
1981             for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
1982               if (i != FoundIndex)
1983                 Ops.push_back(Add->getOperand(i));
1984             SCEVHandle Accum = getAddExpr(Ops);
1985
1986             // This is not a valid addrec if the step amount is varying each
1987             // loop iteration, but is not itself an addrec in this loop.
1988             if (Accum->isLoopInvariant(L) ||
1989                 (isa<SCEVAddRecExpr>(Accum) &&
1990                  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
1991               SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
1992               SCEVHandle PHISCEV  = getAddRecExpr(StartVal, Accum, L);
1993
1994               // Okay, for the entire analysis of this edge we assumed the PHI
1995               // to be symbolic.  We now need to go back and update all of the
1996               // entries for the scalars that use the PHI (except for the PHI
1997               // itself) to use the new analyzed value instead of the "symbolic"
1998               // value.
1999               ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
2000               return PHISCEV;
2001             }
2002           }
2003         } else if (const SCEVAddRecExpr *AddRec =
2004                      dyn_cast<SCEVAddRecExpr>(BEValue)) {
2005           // Otherwise, this could be a loop like this:
2006           //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
2007           // In this case, j = {1,+,1}  and BEValue is j.
2008           // Because the other in-value of i (0) fits the evolution of BEValue
2009           // i really is an addrec evolution.
2010           if (AddRec->getLoop() == L && AddRec->isAffine()) {
2011             SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
2012
2013             // If StartVal = j.start - j.stride, we can use StartVal as the
2014             // initial step of the addrec evolution.
2015             if (StartVal == getMinusSCEV(AddRec->getOperand(0),
2016                                             AddRec->getOperand(1))) {
2017               SCEVHandle PHISCEV = 
2018                  getAddRecExpr(StartVal, AddRec->getOperand(1), L);
2019
2020               // Okay, for the entire analysis of this edge we assumed the PHI
2021               // to be symbolic.  We now need to go back and update all of the
2022               // entries for the scalars that use the PHI (except for the PHI
2023               // itself) to use the new analyzed value instead of the "symbolic"
2024               // value.
2025               ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
2026               return PHISCEV;
2027             }
2028           }
2029         }
2030
2031         return SymbolicName;
2032       }
2033
2034   // If it's not a loop phi, we can't handle it yet.
2035   return getUnknown(PN);
2036 }
2037
2038 /// createNodeForGEP - Expand GEP instructions into add and multiply
2039 /// operations. This allows them to be analyzed by regular SCEV code.
2040 ///
2041 SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) {
2042
2043   const Type *IntPtrTy = TD->getIntPtrType();
2044   Value *Base = GEP->getOperand(0);
2045   // Don't attempt to analyze GEPs over unsized objects.
2046   if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
2047     return getUnknown(GEP);
2048   SCEVHandle TotalOffset = getIntegerSCEV(0, IntPtrTy);
2049   gep_type_iterator GTI = gep_type_begin(GEP);
2050   for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
2051                                       E = GEP->op_end();
2052        I != E; ++I) {
2053     Value *Index = *I;
2054     // Compute the (potentially symbolic) offset in bytes for this index.
2055     if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
2056       // For a struct, add the member offset.
2057       const StructLayout &SL = *TD->getStructLayout(STy);
2058       unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
2059       uint64_t Offset = SL.getElementOffset(FieldNo);
2060       TotalOffset = getAddExpr(TotalOffset,
2061                                   getIntegerSCEV(Offset, IntPtrTy));
2062     } else {
2063       // For an array, add the element offset, explicitly scaled.
2064       SCEVHandle LocalOffset = getSCEV(Index);
2065       if (!isa<PointerType>(LocalOffset->getType()))
2066         // Getelementptr indicies are signed.
2067         LocalOffset = getTruncateOrSignExtend(LocalOffset,
2068                                               IntPtrTy);
2069       LocalOffset =
2070         getMulExpr(LocalOffset,
2071                    getIntegerSCEV(TD->getTypeAllocSize(*GTI),
2072                                   IntPtrTy));
2073       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
2074     }
2075   }
2076   return getAddExpr(getSCEV(Base), TotalOffset);
2077 }
2078
2079 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
2080 /// guaranteed to end in (at every loop iteration).  It is, at the same time,
2081 /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
2082 /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
2083 static uint32_t GetMinTrailingZeros(SCEVHandle S, const ScalarEvolution &SE) {
2084   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2085     return C->getValue()->getValue().countTrailingZeros();
2086
2087   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
2088     return std::min(GetMinTrailingZeros(T->getOperand(), SE),
2089                     (uint32_t)SE.getTypeSizeInBits(T->getType()));
2090
2091   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
2092     uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE);
2093     return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ?
2094              SE.getTypeSizeInBits(E->getType()) : OpRes;
2095   }
2096
2097   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
2098     uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE);
2099     return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ?
2100              SE.getTypeSizeInBits(E->getType()) : OpRes;
2101   }
2102
2103   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2104     // The result is the min of all operands results.
2105     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE);
2106     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2107       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE));
2108     return MinOpRes;
2109   }
2110
2111   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
2112     // The result is the sum of all operands results.
2113     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
2114     uint32_t BitWidth = SE.getTypeSizeInBits(M->getType());
2115     for (unsigned i = 1, e = M->getNumOperands();
2116          SumOpRes != BitWidth && i != e; ++i)
2117       SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), SE),
2118                           BitWidth);
2119     return SumOpRes;
2120   }
2121
2122   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
2123     // The result is the min of all operands results.
2124     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE);
2125     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2126       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE));
2127     return MinOpRes;
2128   }
2129
2130   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
2131     // The result is the min of all operands results.
2132     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
2133     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2134       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE));
2135     return MinOpRes;
2136   }
2137
2138   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
2139     // The result is the min of all operands results.
2140     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE);
2141     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2142       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE));
2143     return MinOpRes;
2144   }
2145
2146   // SCEVUDivExpr, SCEVUnknown
2147   return 0;
2148 }
2149
2150 /// createSCEV - We know that there is no SCEV for the specified value.
2151 /// Analyze the expression.
2152 ///
2153 SCEVHandle ScalarEvolution::createSCEV(Value *V) {
2154   if (!isSCEVable(V->getType()))
2155     return getUnknown(V);
2156
2157   unsigned Opcode = Instruction::UserOp1;
2158   if (Instruction *I = dyn_cast<Instruction>(V))
2159     Opcode = I->getOpcode();
2160   else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
2161     Opcode = CE->getOpcode();
2162   else
2163     return getUnknown(V);
2164
2165   User *U = cast<User>(V);
2166   switch (Opcode) {
2167   case Instruction::Add:
2168     return getAddExpr(getSCEV(U->getOperand(0)),
2169                       getSCEV(U->getOperand(1)));
2170   case Instruction::Mul:
2171     return getMulExpr(getSCEV(U->getOperand(0)),
2172                       getSCEV(U->getOperand(1)));
2173   case Instruction::UDiv:
2174     return getUDivExpr(getSCEV(U->getOperand(0)),
2175                        getSCEV(U->getOperand(1)));
2176   case Instruction::Sub:
2177     return getMinusSCEV(getSCEV(U->getOperand(0)),
2178                         getSCEV(U->getOperand(1)));
2179   case Instruction::And:
2180     // For an expression like x&255 that merely masks off the high bits,
2181     // use zext(trunc(x)) as the SCEV expression.
2182     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2183       if (CI->isNullValue())
2184         return getSCEV(U->getOperand(1));
2185       if (CI->isAllOnesValue())
2186         return getSCEV(U->getOperand(0));
2187       const APInt &A = CI->getValue();
2188       unsigned Ones = A.countTrailingOnes();
2189       if (APIntOps::isMask(Ones, A))
2190         return
2191           getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
2192                                             IntegerType::get(Ones)),
2193                             U->getType());
2194     }
2195     break;
2196   case Instruction::Or:
2197     // If the RHS of the Or is a constant, we may have something like:
2198     // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
2199     // optimizations will transparently handle this case.
2200     //
2201     // In order for this transformation to be safe, the LHS must be of the
2202     // form X*(2^n) and the Or constant must be less than 2^n.
2203     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2204       SCEVHandle LHS = getSCEV(U->getOperand(0));
2205       const APInt &CIVal = CI->getValue();
2206       if (GetMinTrailingZeros(LHS, *this) >=
2207           (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
2208         return getAddExpr(LHS, getSCEV(U->getOperand(1)));
2209     }
2210     break;
2211   case Instruction::Xor:
2212     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2213       // If the RHS of the xor is a signbit, then this is just an add.
2214       // Instcombine turns add of signbit into xor as a strength reduction step.
2215       if (CI->getValue().isSignBit())
2216         return getAddExpr(getSCEV(U->getOperand(0)),
2217                           getSCEV(U->getOperand(1)));
2218
2219       // If the RHS of xor is -1, then this is a not operation.
2220       if (CI->isAllOnesValue())
2221         return getNotSCEV(getSCEV(U->getOperand(0)));
2222
2223       // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
2224       // This is a variant of the check for xor with -1, and it handles
2225       // the case where instcombine has trimmed non-demanded bits out
2226       // of an xor with -1.
2227       if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
2228         if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
2229           if (BO->getOpcode() == Instruction::And &&
2230               LCI->getValue() == CI->getValue())
2231             if (const SCEVZeroExtendExpr *Z =
2232                   dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0))))
2233               return getZeroExtendExpr(getNotSCEV(Z->getOperand()),
2234                                        U->getType());
2235     }
2236     break;
2237
2238   case Instruction::Shl:
2239     // Turn shift left of a constant amount into a multiply.
2240     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2241       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2242       Constant *X = ConstantInt::get(
2243         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2244       return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2245     }
2246     break;
2247
2248   case Instruction::LShr:
2249     // Turn logical shift right of a constant into a unsigned divide.
2250     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2251       uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2252       Constant *X = ConstantInt::get(
2253         APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2254       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2255     }
2256     break;
2257
2258   case Instruction::AShr:
2259     // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
2260     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
2261       if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0)))
2262         if (L->getOpcode() == Instruction::Shl &&
2263             L->getOperand(1) == U->getOperand(1)) {
2264           unsigned BitWidth = getTypeSizeInBits(U->getType());
2265           uint64_t Amt = BitWidth - CI->getZExtValue();
2266           if (Amt == BitWidth)
2267             return getSCEV(L->getOperand(0));       // shift by zero --> noop
2268           if (Amt > BitWidth)
2269             return getIntegerSCEV(0, U->getType()); // value is undefined
2270           return
2271             getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
2272                                                       IntegerType::get(Amt)),
2273                                  U->getType());
2274         }
2275     break;
2276
2277   case Instruction::Trunc:
2278     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
2279
2280   case Instruction::ZExt:
2281     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2282
2283   case Instruction::SExt:
2284     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2285
2286   case Instruction::BitCast:
2287     // BitCasts are no-op casts so we just eliminate the cast.
2288     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
2289       return getSCEV(U->getOperand(0));
2290     break;
2291
2292   case Instruction::IntToPtr:
2293     if (!TD) break; // Without TD we can't analyze pointers.
2294     return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2295                                    TD->getIntPtrType());
2296
2297   case Instruction::PtrToInt:
2298     if (!TD) break; // Without TD we can't analyze pointers.
2299     return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2300                                    U->getType());
2301
2302   case Instruction::GetElementPtr:
2303     if (!TD) break; // Without TD we can't analyze pointers.
2304     return createNodeForGEP(U);
2305
2306   case Instruction::PHI:
2307     return createNodeForPHI(cast<PHINode>(U));
2308
2309   case Instruction::Select:
2310     // This could be a smax or umax that was lowered earlier.
2311     // Try to recover it.
2312     if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
2313       Value *LHS = ICI->getOperand(0);
2314       Value *RHS = ICI->getOperand(1);
2315       switch (ICI->getPredicate()) {
2316       case ICmpInst::ICMP_SLT:
2317       case ICmpInst::ICMP_SLE:
2318         std::swap(LHS, RHS);
2319         // fall through
2320       case ICmpInst::ICMP_SGT:
2321       case ICmpInst::ICMP_SGE:
2322         if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2323           return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
2324         else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2325           // ~smax(~x, ~y) == smin(x, y).
2326           return getNotSCEV(getSMaxExpr(
2327                                    getNotSCEV(getSCEV(LHS)),
2328                                    getNotSCEV(getSCEV(RHS))));
2329         break;
2330       case ICmpInst::ICMP_ULT:
2331       case ICmpInst::ICMP_ULE:
2332         std::swap(LHS, RHS);
2333         // fall through
2334       case ICmpInst::ICMP_UGT:
2335       case ICmpInst::ICMP_UGE:
2336         if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2337           return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
2338         else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2339           // ~umax(~x, ~y) == umin(x, y)
2340           return getNotSCEV(getUMaxExpr(getNotSCEV(getSCEV(LHS)),
2341                                         getNotSCEV(getSCEV(RHS))));
2342         break;
2343       default:
2344         break;
2345       }
2346     }
2347
2348   default: // We cannot analyze this expression.
2349     break;
2350   }
2351
2352   return getUnknown(V);
2353 }
2354
2355
2356
2357 //===----------------------------------------------------------------------===//
2358 //                   Iteration Count Computation Code
2359 //
2360
2361 /// getBackedgeTakenCount - If the specified loop has a predictable
2362 /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
2363 /// object. The backedge-taken count is the number of times the loop header
2364 /// will be branched to from within the loop. This is one less than the
2365 /// trip count of the loop, since it doesn't count the first iteration,
2366 /// when the header is branched to from outside the loop.
2367 ///
2368 /// Note that it is not valid to call this method on a loop without a
2369 /// loop-invariant backedge-taken count (see
2370 /// hasLoopInvariantBackedgeTakenCount).
2371 ///
2372 SCEVHandle ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
2373   return getBackedgeTakenInfo(L).Exact;
2374 }
2375
2376 /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
2377 /// return the least SCEV value that is known never to be less than the
2378 /// actual backedge taken count.
2379 SCEVHandle ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
2380   return getBackedgeTakenInfo(L).Max;
2381 }
2382
2383 const ScalarEvolution::BackedgeTakenInfo &
2384 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
2385   // Initially insert a CouldNotCompute for this loop. If the insertion
2386   // succeeds, procede to actually compute a backedge-taken count and
2387   // update the value. The temporary CouldNotCompute value tells SCEV
2388   // code elsewhere that it shouldn't attempt to request a new
2389   // backedge-taken count, which could result in infinite recursion.
2390   std::pair<std::map<const Loop*, BackedgeTakenInfo>::iterator, bool> Pair =
2391     BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
2392   if (Pair.second) {
2393     BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L);
2394     if (ItCount.Exact != UnknownValue) {
2395       assert(ItCount.Exact->isLoopInvariant(L) &&
2396              ItCount.Max->isLoopInvariant(L) &&
2397              "Computed trip count isn't loop invariant for loop!");
2398       ++NumTripCountsComputed;
2399
2400       // Update the value in the map.
2401       Pair.first->second = ItCount;
2402     } else if (isa<PHINode>(L->getHeader()->begin())) {
2403       // Only count loops that have phi nodes as not being computable.
2404       ++NumTripCountsNotComputed;
2405     }
2406
2407     // Now that we know more about the trip count for this loop, forget any
2408     // existing SCEV values for PHI nodes in this loop since they are only
2409     // conservative estimates made without the benefit
2410     // of trip count information.
2411     if (ItCount.hasAnyInfo())
2412       forgetLoopPHIs(L);
2413   }
2414   return Pair.first->second;
2415 }
2416
2417 /// forgetLoopBackedgeTakenCount - This method should be called by the
2418 /// client when it has changed a loop in a way that may effect
2419 /// ScalarEvolution's ability to compute a trip count, or if the loop
2420 /// is deleted.
2421 void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
2422   BackedgeTakenCounts.erase(L);
2423   forgetLoopPHIs(L);
2424 }
2425
2426 /// forgetLoopPHIs - Delete the memoized SCEVs associated with the
2427 /// PHI nodes in the given loop. This is used when the trip count of
2428 /// the loop may have changed.
2429 void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
2430   BasicBlock *Header = L->getHeader();
2431
2432   // Push all Loop-header PHIs onto the Worklist stack, except those
2433   // that are presently represented via a SCEVUnknown. SCEVUnknown for
2434   // a PHI either means that it has an unrecognized structure, or it's
2435   // a PHI that's in the progress of being computed by createNodeForPHI.
2436   // In the former case, additional loop trip count information isn't
2437   // going to change anything. In the later case, createNodeForPHI will
2438   // perform the necessary updates on its own when it gets to that point.
2439   SmallVector<Instruction *, 16> Worklist;
2440   for (BasicBlock::iterator I = Header->begin();
2441        PHINode *PN = dyn_cast<PHINode>(I); ++I) {
2442     std::map<SCEVCallbackVH, SCEVHandle>::iterator It = Scalars.find((Value*)I);
2443     if (It != Scalars.end() && !isa<SCEVUnknown>(It->second))
2444       Worklist.push_back(PN);
2445   }
2446
2447   while (!Worklist.empty()) {
2448     Instruction *I = Worklist.pop_back_val();
2449     if (Scalars.erase(I))
2450       for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2451            UI != UE; ++UI)
2452         Worklist.push_back(cast<Instruction>(UI));
2453   }
2454 }
2455
2456 /// ComputeBackedgeTakenCount - Compute the number of times the backedge
2457 /// of the specified loop will execute.
2458 ScalarEvolution::BackedgeTakenInfo
2459 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
2460   // If the loop has a non-one exit block count, we can't analyze it.
2461   SmallVector<BasicBlock*, 8> ExitBlocks;
2462   L->getExitBlocks(ExitBlocks);
2463   if (ExitBlocks.size() != 1) return UnknownValue;
2464
2465   // Okay, there is one exit block.  Try to find the condition that causes the
2466   // loop to be exited.
2467   BasicBlock *ExitBlock = ExitBlocks[0];
2468
2469   BasicBlock *ExitingBlock = 0;
2470   for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock);
2471        PI != E; ++PI)
2472     if (L->contains(*PI)) {
2473       if (ExitingBlock == 0)
2474         ExitingBlock = *PI;
2475       else
2476         return UnknownValue;   // More than one block exiting!
2477     }
2478   assert(ExitingBlock && "No exits from loop, something is broken!");
2479
2480   // Okay, we've computed the exiting block.  See what condition causes us to
2481   // exit.
2482   //
2483   // FIXME: we should be able to handle switch instructions (with a single exit)
2484   BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
2485   if (ExitBr == 0) return UnknownValue;
2486   assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
2487   
2488   // At this point, we know we have a conditional branch that determines whether
2489   // the loop is exited.  However, we don't know if the branch is executed each
2490   // time through the loop.  If not, then the execution count of the branch will
2491   // not be equal to the trip count of the loop.
2492   //
2493   // Currently we check for this by checking to see if the Exit branch goes to
2494   // the loop header.  If so, we know it will always execute the same number of
2495   // times as the loop.  We also handle the case where the exit block *is* the
2496   // loop header.  This is common for un-rotated loops.  More extensive analysis
2497   // could be done to handle more cases here.
2498   if (ExitBr->getSuccessor(0) != L->getHeader() &&
2499       ExitBr->getSuccessor(1) != L->getHeader() &&
2500       ExitBr->getParent() != L->getHeader())
2501     return UnknownValue;
2502   
2503   ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition());
2504
2505   // If it's not an integer or pointer comparison then compute it the hard way.
2506   if (ExitCond == 0)
2507     return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(),
2508                                           ExitBr->getSuccessor(0) == ExitBlock);
2509
2510   // If the condition was exit on true, convert the condition to exit on false
2511   ICmpInst::Predicate Cond;
2512   if (ExitBr->getSuccessor(1) == ExitBlock)
2513     Cond = ExitCond->getPredicate();
2514   else
2515     Cond = ExitCond->getInversePredicate();
2516
2517   // Handle common loops like: for (X = "string"; *X; ++X)
2518   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
2519     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
2520       SCEVHandle ItCnt =
2521         ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
2522       if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt;
2523     }
2524
2525   SCEVHandle LHS = getSCEV(ExitCond->getOperand(0));
2526   SCEVHandle RHS = getSCEV(ExitCond->getOperand(1));
2527
2528   // Try to evaluate any dependencies out of the loop.
2529   LHS = getSCEVAtScope(LHS, L);
2530   RHS = getSCEVAtScope(RHS, L);
2531
2532   // At this point, we would like to compute how many iterations of the 
2533   // loop the predicate will return true for these inputs.
2534   if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
2535     // If there is a loop-invariant, force it into the RHS.
2536     std::swap(LHS, RHS);
2537     Cond = ICmpInst::getSwappedPredicate(Cond);
2538   }
2539
2540   // If we have a comparison of a chrec against a constant, try to use value
2541   // ranges to answer this query.
2542   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
2543     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
2544       if (AddRec->getLoop() == L) {
2545         // Form the constant range.
2546         ConstantRange CompRange(
2547             ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
2548
2549         SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, *this);
2550         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
2551       }
2552
2553   switch (Cond) {
2554   case ICmpInst::ICMP_NE: {                     // while (X != Y)
2555     // Convert to: while (X-Y != 0)
2556     SCEVHandle TC = HowFarToZero(getMinusSCEV(LHS, RHS), L);
2557     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2558     break;
2559   }
2560   case ICmpInst::ICMP_EQ: {
2561     // Convert to: while (X-Y == 0)           // while (X == Y)
2562     SCEVHandle TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
2563     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
2564     break;
2565   }
2566   case ICmpInst::ICMP_SLT: {
2567     BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
2568     if (BTI.hasAnyInfo()) return BTI;
2569     break;
2570   }
2571   case ICmpInst::ICMP_SGT: {
2572     BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
2573                                              getNotSCEV(RHS), L, true);
2574     if (BTI.hasAnyInfo()) return BTI;
2575     break;
2576   }
2577   case ICmpInst::ICMP_ULT: {
2578     BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
2579     if (BTI.hasAnyInfo()) return BTI;
2580     break;
2581   }
2582   case ICmpInst::ICMP_UGT: {
2583     BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
2584                                              getNotSCEV(RHS), L, false);
2585     if (BTI.hasAnyInfo()) return BTI;
2586     break;
2587   }
2588   default:
2589 #if 0
2590     errs() << "ComputeBackedgeTakenCount ";
2591     if (ExitCond->getOperand(0)->getType()->isUnsigned())
2592       errs() << "[unsigned] ";
2593     errs() << *LHS << "   "
2594          << Instruction::getOpcodeName(Instruction::ICmp) 
2595          << "   " << *RHS << "\n";
2596 #endif
2597     break;
2598   }
2599   return
2600     ComputeBackedgeTakenCountExhaustively(L, ExitCond,
2601                                           ExitBr->getSuccessor(0) == ExitBlock);
2602 }
2603
2604 static ConstantInt *
2605 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
2606                                 ScalarEvolution &SE) {
2607   SCEVHandle InVal = SE.getConstant(C);
2608   SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE);
2609   assert(isa<SCEVConstant>(Val) &&
2610          "Evaluation of SCEV at constant didn't fold correctly?");
2611   return cast<SCEVConstant>(Val)->getValue();
2612 }
2613
2614 /// GetAddressedElementFromGlobal - Given a global variable with an initializer
2615 /// and a GEP expression (missing the pointer index) indexing into it, return
2616 /// the addressed element of the initializer or null if the index expression is
2617 /// invalid.
2618 static Constant *
2619 GetAddressedElementFromGlobal(GlobalVariable *GV,
2620                               const std::vector<ConstantInt*> &Indices) {
2621   Constant *Init = GV->getInitializer();
2622   for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
2623     uint64_t Idx = Indices[i]->getZExtValue();
2624     if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
2625       assert(Idx < CS->getNumOperands() && "Bad struct index!");
2626       Init = cast<Constant>(CS->getOperand(Idx));
2627     } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
2628       if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
2629       Init = cast<Constant>(CA->getOperand(Idx));
2630     } else if (isa<ConstantAggregateZero>(Init)) {
2631       if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
2632         assert(Idx < STy->getNumElements() && "Bad struct index!");
2633         Init = Constant::getNullValue(STy->getElementType(Idx));
2634       } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
2635         if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
2636         Init = Constant::getNullValue(ATy->getElementType());
2637       } else {
2638         assert(0 && "Unknown constant aggregate type!");
2639       }
2640       return 0;
2641     } else {
2642       return 0; // Unknown initializer type
2643     }
2644   }
2645   return Init;
2646 }
2647
2648 /// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
2649 /// 'icmp op load X, cst', try to see if we can compute the backedge
2650 /// execution count.
2651 SCEVHandle ScalarEvolution::
2652 ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS,
2653                                              const Loop *L,
2654                                              ICmpInst::Predicate predicate) {
2655   if (LI->isVolatile()) return UnknownValue;
2656
2657   // Check to see if the loaded pointer is a getelementptr of a global.
2658   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
2659   if (!GEP) return UnknownValue;
2660
2661   // Make sure that it is really a constant global we are gepping, with an
2662   // initializer, and make sure the first IDX is really 0.
2663   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
2664   if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
2665       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
2666       !cast<Constant>(GEP->getOperand(1))->isNullValue())
2667     return UnknownValue;
2668
2669   // Okay, we allow one non-constant index into the GEP instruction.
2670   Value *VarIdx = 0;
2671   std::vector<ConstantInt*> Indexes;
2672   unsigned VarIdxNum = 0;
2673   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
2674     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
2675       Indexes.push_back(CI);
2676     } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
2677       if (VarIdx) return UnknownValue;  // Multiple non-constant idx's.
2678       VarIdx = GEP->getOperand(i);
2679       VarIdxNum = i-2;
2680       Indexes.push_back(0);
2681     }
2682
2683   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
2684   // Check to see if X is a loop variant variable value now.
2685   SCEVHandle Idx = getSCEV(VarIdx);
2686   Idx = getSCEVAtScope(Idx, L);
2687
2688   // We can only recognize very limited forms of loop index expressions, in
2689   // particular, only affine AddRec's like {C1,+,C2}.
2690   const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
2691   if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
2692       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
2693       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
2694     return UnknownValue;
2695
2696   unsigned MaxSteps = MaxBruteForceIterations;
2697   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
2698     ConstantInt *ItCst =
2699       ConstantInt::get(IdxExpr->getType(), IterationNum);
2700     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
2701
2702     // Form the GEP offset.
2703     Indexes[VarIdxNum] = Val;
2704
2705     Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
2706     if (Result == 0) break;  // Cannot compute!
2707
2708     // Evaluate the condition for this iteration.
2709     Result = ConstantExpr::getICmp(predicate, Result, RHS);
2710     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
2711     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
2712 #if 0
2713       errs() << "\n***\n*** Computed loop count " << *ItCst
2714              << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
2715              << "***\n";
2716 #endif
2717       ++NumArrayLenItCounts;
2718       return getConstant(ItCst);   // Found terminating iteration!
2719     }
2720   }
2721   return UnknownValue;
2722 }
2723
2724
2725 /// CanConstantFold - Return true if we can constant fold an instruction of the
2726 /// specified type, assuming that all operands were constants.
2727 static bool CanConstantFold(const Instruction *I) {
2728   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
2729       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
2730     return true;
2731
2732   if (const CallInst *CI = dyn_cast<CallInst>(I))
2733     if (const Function *F = CI->getCalledFunction())
2734       return canConstantFoldCallTo(F);
2735   return false;
2736 }
2737
2738 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
2739 /// in the loop that V is derived from.  We allow arbitrary operations along the
2740 /// way, but the operands of an operation must either be constants or a value
2741 /// derived from a constant PHI.  If this expression does not fit with these
2742 /// constraints, return null.
2743 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
2744   // If this is not an instruction, or if this is an instruction outside of the
2745   // loop, it can't be derived from a loop PHI.
2746   Instruction *I = dyn_cast<Instruction>(V);
2747   if (I == 0 || !L->contains(I->getParent())) return 0;
2748
2749   if (PHINode *PN = dyn_cast<PHINode>(I)) {
2750     if (L->getHeader() == I->getParent())
2751       return PN;
2752     else
2753       // We don't currently keep track of the control flow needed to evaluate
2754       // PHIs, so we cannot handle PHIs inside of loops.
2755       return 0;
2756   }
2757
2758   // If we won't be able to constant fold this expression even if the operands
2759   // are constants, return early.
2760   if (!CanConstantFold(I)) return 0;
2761
2762   // Otherwise, we can evaluate this instruction if all of its operands are
2763   // constant or derived from a PHI node themselves.
2764   PHINode *PHI = 0;
2765   for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
2766     if (!(isa<Constant>(I->getOperand(Op)) ||
2767           isa<GlobalValue>(I->getOperand(Op)))) {
2768       PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
2769       if (P == 0) return 0;  // Not evolving from PHI
2770       if (PHI == 0)
2771         PHI = P;
2772       else if (PHI != P)
2773         return 0;  // Evolving from multiple different PHIs.
2774     }
2775
2776   // This is a expression evolving from a constant PHI!
2777   return PHI;
2778 }
2779
2780 /// EvaluateExpression - Given an expression that passes the
2781 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
2782 /// in the loop has the value PHIVal.  If we can't fold this expression for some
2783 /// reason, return null.
2784 static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
2785   if (isa<PHINode>(V)) return PHIVal;
2786   if (Constant *C = dyn_cast<Constant>(V)) return C;
2787   if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
2788   Instruction *I = cast<Instruction>(V);
2789
2790   std::vector<Constant*> Operands;
2791   Operands.resize(I->getNumOperands());
2792
2793   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2794     Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
2795     if (Operands[i] == 0) return 0;
2796   }
2797
2798   if (const CmpInst *CI = dyn_cast<CmpInst>(I))
2799     return ConstantFoldCompareInstOperands(CI->getPredicate(),
2800                                            &Operands[0], Operands.size());
2801   else
2802     return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
2803                                     &Operands[0], Operands.size());
2804 }
2805
2806 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
2807 /// in the header of its containing loop, we know the loop executes a
2808 /// constant number of times, and the PHI node is just a recurrence
2809 /// involving constants, fold it.
2810 Constant *ScalarEvolution::
2811 getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
2812   std::map<PHINode*, Constant*>::iterator I =
2813     ConstantEvolutionLoopExitValue.find(PN);
2814   if (I != ConstantEvolutionLoopExitValue.end())
2815     return I->second;
2816
2817   if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
2818     return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
2819
2820   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
2821
2822   // Since the loop is canonicalized, the PHI node must have two entries.  One
2823   // entry must be a constant (coming in from outside of the loop), and the
2824   // second must be derived from the same PHI.
2825   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2826   Constant *StartCST =
2827     dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2828   if (StartCST == 0)
2829     return RetVal = 0;  // Must be a constant.
2830
2831   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2832   PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2833   if (PN2 != PN)
2834     return RetVal = 0;  // Not derived from same PHI.
2835
2836   // Execute the loop symbolically to determine the exit value.
2837   if (BEs.getActiveBits() >= 32)
2838     return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
2839
2840   unsigned NumIterations = BEs.getZExtValue(); // must be in range
2841   unsigned IterationNum = 0;
2842   for (Constant *PHIVal = StartCST; ; ++IterationNum) {
2843     if (IterationNum == NumIterations)
2844       return RetVal = PHIVal;  // Got exit value!
2845
2846     // Compute the value of the PHI node for the next iteration.
2847     Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2848     if (NextPHI == PHIVal)
2849       return RetVal = NextPHI;  // Stopped evolving!
2850     if (NextPHI == 0)
2851       return 0;        // Couldn't evaluate!
2852     PHIVal = NextPHI;
2853   }
2854 }
2855
2856 /// ComputeBackedgeTakenCountExhaustively - If the trip is known to execute a
2857 /// constant number of times (the condition evolves only from constants),
2858 /// try to evaluate a few iterations of the loop until we get the exit
2859 /// condition gets a value of ExitWhen (true or false).  If we cannot
2860 /// evaluate the trip count of the loop, return UnknownValue.
2861 SCEVHandle ScalarEvolution::
2862 ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) {
2863   PHINode *PN = getConstantEvolvingPHI(Cond, L);
2864   if (PN == 0) return UnknownValue;
2865
2866   // Since the loop is canonicalized, the PHI node must have two entries.  One
2867   // entry must be a constant (coming in from outside of the loop), and the
2868   // second must be derived from the same PHI.
2869   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
2870   Constant *StartCST =
2871     dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
2872   if (StartCST == 0) return UnknownValue;  // Must be a constant.
2873
2874   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
2875   PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
2876   if (PN2 != PN) return UnknownValue;  // Not derived from same PHI.
2877
2878   // Okay, we find a PHI node that defines the trip count of this loop.  Execute
2879   // the loop symbolically to determine when the condition gets a value of
2880   // "ExitWhen".
2881   unsigned IterationNum = 0;
2882   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
2883   for (Constant *PHIVal = StartCST;
2884        IterationNum != MaxIterations; ++IterationNum) {
2885     ConstantInt *CondVal =
2886       dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
2887
2888     // Couldn't symbolically evaluate.
2889     if (!CondVal) return UnknownValue;
2890
2891     if (CondVal->getValue() == uint64_t(ExitWhen)) {
2892       ConstantEvolutionLoopExitValue[PN] = PHIVal;
2893       ++NumBruteForceTripCountsComputed;
2894       return getConstant(ConstantInt::get(Type::Int32Ty, IterationNum));
2895     }
2896
2897     // Compute the value of the PHI node for the next iteration.
2898     Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
2899     if (NextPHI == 0 || NextPHI == PHIVal)
2900       return UnknownValue;  // Couldn't evaluate or not making progress...
2901     PHIVal = NextPHI;
2902   }
2903
2904   // Too many iterations were needed to evaluate.
2905   return UnknownValue;
2906 }
2907
2908 /// getSCEVAtScope - Return a SCEV expression handle for the specified value
2909 /// at the specified scope in the program.  The L value specifies a loop
2910 /// nest to evaluate the expression at, where null is the top-level or a
2911 /// specified loop is immediately inside of the loop.
2912 ///
2913 /// This method can be used to compute the exit value for a variable defined
2914 /// in a loop by querying what the value will hold in the parent loop.
2915 ///
2916 /// In the case that a relevant loop exit value cannot be computed, the
2917 /// original value V is returned.
2918 SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
2919   // FIXME: this should be turned into a virtual method on SCEV!
2920
2921   if (isa<SCEVConstant>(V)) return V;
2922
2923   // If this instruction is evolved from a constant-evolving PHI, compute the
2924   // exit value from the loop without using SCEVs.
2925   if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
2926     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
2927       const Loop *LI = (*this->LI)[I->getParent()];
2928       if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
2929         if (PHINode *PN = dyn_cast<PHINode>(I))
2930           if (PN->getParent() == LI->getHeader()) {
2931             // Okay, there is no closed form solution for the PHI node.  Check
2932             // to see if the loop that contains it has a known backedge-taken
2933             // count.  If so, we may be able to force computation of the exit
2934             // value.
2935             SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(LI);
2936             if (const SCEVConstant *BTCC =
2937                   dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
2938               // Okay, we know how many times the containing loop executes.  If
2939               // this is a constant evolving PHI node, get the final value at
2940               // the specified iteration number.
2941               Constant *RV = getConstantEvolutionLoopExitValue(PN,
2942                                                    BTCC->getValue()->getValue(),
2943                                                                LI);
2944               if (RV) return getUnknown(RV);
2945             }
2946           }
2947
2948       // Okay, this is an expression that we cannot symbolically evaluate
2949       // into a SCEV.  Check to see if it's possible to symbolically evaluate
2950       // the arguments into constants, and if so, try to constant propagate the
2951       // result.  This is particularly useful for computing loop exit values.
2952       if (CanConstantFold(I)) {
2953         // Check to see if we've folded this instruction at this loop before.
2954         std::map<const Loop *, Constant *> &Values = ValuesAtScopes[I];
2955         std::pair<std::map<const Loop *, Constant *>::iterator, bool> Pair =
2956           Values.insert(std::make_pair(L, static_cast<Constant *>(0)));
2957         if (!Pair.second)
2958           return Pair.first->second ? &*getUnknown(Pair.first->second) : V;
2959
2960         std::vector<Constant*> Operands;
2961         Operands.reserve(I->getNumOperands());
2962         for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
2963           Value *Op = I->getOperand(i);
2964           if (Constant *C = dyn_cast<Constant>(Op)) {
2965             Operands.push_back(C);
2966           } else {
2967             // If any of the operands is non-constant and if they are
2968             // non-integer and non-pointer, don't even try to analyze them
2969             // with scev techniques.
2970             if (!isSCEVable(Op->getType()))
2971               return V;
2972
2973             SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L);
2974             if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV)) {
2975               Constant *C = SC->getValue();
2976               if (C->getType() != Op->getType())
2977                 C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
2978                                                                   Op->getType(),
2979                                                                   false),
2980                                           C, Op->getType());
2981               Operands.push_back(C);
2982             } else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
2983               if (Constant *C = dyn_cast<Constant>(SU->getValue())) {
2984                 if (C->getType() != Op->getType())
2985                   C =
2986                     ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
2987                                                                   Op->getType(),
2988                                                                   false),
2989                                           C, Op->getType());
2990                 Operands.push_back(C);
2991               } else
2992                 return V;
2993             } else {
2994               return V;
2995             }
2996           }
2997         }
2998         
2999         Constant *C;
3000         if (const CmpInst *CI = dyn_cast<CmpInst>(I))
3001           C = ConstantFoldCompareInstOperands(CI->getPredicate(),
3002                                               &Operands[0], Operands.size());
3003         else
3004           C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
3005                                        &Operands[0], Operands.size());
3006         Pair.first->second = C;
3007         return getUnknown(C);
3008       }
3009     }
3010
3011     // This is some other type of SCEVUnknown, just return it.
3012     return V;
3013   }
3014
3015   if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
3016     // Avoid performing the look-up in the common case where the specified
3017     // expression has no loop-variant portions.
3018     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
3019       SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3020       if (OpAtScope != Comm->getOperand(i)) {
3021         // Okay, at least one of these operands is loop variant but might be
3022         // foldable.  Build a new instance of the folded commutative expression.
3023         std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i);
3024         NewOps.push_back(OpAtScope);
3025
3026         for (++i; i != e; ++i) {
3027           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3028           NewOps.push_back(OpAtScope);
3029         }
3030         if (isa<SCEVAddExpr>(Comm))
3031           return getAddExpr(NewOps);
3032         if (isa<SCEVMulExpr>(Comm))
3033           return getMulExpr(NewOps);
3034         if (isa<SCEVSMaxExpr>(Comm))
3035           return getSMaxExpr(NewOps);
3036         if (isa<SCEVUMaxExpr>(Comm))
3037           return getUMaxExpr(NewOps);
3038         assert(0 && "Unknown commutative SCEV type!");
3039       }
3040     }
3041     // If we got here, all operands are loop invariant.
3042     return Comm;
3043   }
3044
3045   if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
3046     SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L);
3047     SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L);
3048     if (LHS == Div->getLHS() && RHS == Div->getRHS())
3049       return Div;   // must be loop invariant
3050     return getUDivExpr(LHS, RHS);
3051   }
3052
3053   // If this is a loop recurrence for a loop that does not contain L, then we
3054   // are dealing with the final value computed by the loop.
3055   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
3056     if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
3057       // To evaluate this recurrence, we need to know how many times the AddRec
3058       // loop iterates.  Compute this now.
3059       SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
3060       if (BackedgeTakenCount == UnknownValue) return AddRec;
3061
3062       // Then, evaluate the AddRec.
3063       return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
3064     }
3065     return AddRec;
3066   }
3067
3068   if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
3069     SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L);
3070     if (Op == Cast->getOperand())
3071       return Cast;  // must be loop invariant
3072     return getZeroExtendExpr(Op, Cast->getType());
3073   }
3074
3075   if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
3076     SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L);
3077     if (Op == Cast->getOperand())
3078       return Cast;  // must be loop invariant
3079     return getSignExtendExpr(Op, Cast->getType());
3080   }
3081
3082   if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
3083     SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L);
3084     if (Op == Cast->getOperand())
3085       return Cast;  // must be loop invariant
3086     return getTruncateExpr(Op, Cast->getType());
3087   }
3088
3089   assert(0 && "Unknown SCEV type!");
3090   return 0;
3091 }
3092
3093 /// getSCEVAtScope - This is a convenience function which does
3094 /// getSCEVAtScope(getSCEV(V), L).
3095 SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
3096   return getSCEVAtScope(getSCEV(V), L);
3097 }
3098
3099 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
3100 /// following equation:
3101 ///
3102 ///     A * X = B (mod N)
3103 ///
3104 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
3105 /// A and B isn't important.
3106 ///
3107 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
3108 static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
3109                                                ScalarEvolution &SE) {
3110   uint32_t BW = A.getBitWidth();
3111   assert(BW == B.getBitWidth() && "Bit widths must be the same.");
3112   assert(A != 0 && "A must be non-zero.");
3113
3114   // 1. D = gcd(A, N)
3115   //
3116   // The gcd of A and N may have only one prime factor: 2. The number of
3117   // trailing zeros in A is its multiplicity
3118   uint32_t Mult2 = A.countTrailingZeros();
3119   // D = 2^Mult2
3120
3121   // 2. Check if B is divisible by D.
3122   //
3123   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
3124   // is not less than multiplicity of this prime factor for D.
3125   if (B.countTrailingZeros() < Mult2)
3126     return SE.getCouldNotCompute();
3127
3128   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
3129   // modulo (N / D).
3130   //
3131   // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
3132   // bit width during computations.
3133   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
3134   APInt Mod(BW + 1, 0);
3135   Mod.set(BW - Mult2);  // Mod = N / D
3136   APInt I = AD.multiplicativeInverse(Mod);
3137
3138   // 4. Compute the minimum unsigned root of the equation:
3139   // I * (B / D) mod (N / D)
3140   APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
3141
3142   // The result is guaranteed to be less than 2^BW so we may truncate it to BW
3143   // bits.
3144   return SE.getConstant(Result.trunc(BW));
3145 }
3146
3147 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the
3148 /// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
3149 /// might be the same) or two SCEVCouldNotCompute objects.
3150 ///
3151 static std::pair<SCEVHandle,SCEVHandle>
3152 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
3153   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
3154   const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
3155   const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
3156   const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
3157
3158   // We currently can only solve this if the coefficients are constants.
3159   if (!LC || !MC || !NC) {
3160     const SCEV *CNC = SE.getCouldNotCompute();
3161     return std::make_pair(CNC, CNC);
3162   }
3163
3164   uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
3165   const APInt &L = LC->getValue()->getValue();
3166   const APInt &M = MC->getValue()->getValue();
3167   const APInt &N = NC->getValue()->getValue();
3168   APInt Two(BitWidth, 2);
3169   APInt Four(BitWidth, 4);
3170
3171   { 
3172     using namespace APIntOps;
3173     const APInt& C = L;
3174     // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
3175     // The B coefficient is M-N/2
3176     APInt B(M);
3177     B -= sdiv(N,Two);
3178
3179     // The A coefficient is N/2
3180     APInt A(N.sdiv(Two));
3181
3182     // Compute the B^2-4ac term.
3183     APInt SqrtTerm(B);
3184     SqrtTerm *= B;
3185     SqrtTerm -= Four * (A * C);
3186
3187     // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
3188     // integer value or else APInt::sqrt() will assert.
3189     APInt SqrtVal(SqrtTerm.sqrt());
3190
3191     // Compute the two solutions for the quadratic formula. 
3192     // The divisions must be performed as signed divisions.
3193     APInt NegB(-B);
3194     APInt TwoA( A << 1 );
3195     if (TwoA.isMinValue()) {
3196       const SCEV *CNC = SE.getCouldNotCompute();
3197       return std::make_pair(CNC, CNC);
3198     }
3199
3200     ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
3201     ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
3202
3203     return std::make_pair(SE.getConstant(Solution1), 
3204                           SE.getConstant(Solution2));
3205     } // end APIntOps namespace
3206 }
3207
3208 /// HowFarToZero - Return the number of times a backedge comparing the specified
3209 /// value to zero will execute.  If not computable, return UnknownValue.
3210 SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
3211   // If the value is a constant
3212   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
3213     // If the value is already zero, the branch will execute zero times.
3214     if (C->getValue()->isZero()) return C;
3215     return UnknownValue;  // Otherwise it will loop infinitely.
3216   }
3217
3218   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
3219   if (!AddRec || AddRec->getLoop() != L)
3220     return UnknownValue;
3221
3222   if (AddRec->isAffine()) {
3223     // If this is an affine expression, the execution count of this branch is
3224     // the minimum unsigned root of the following equation:
3225     //
3226     //     Start + Step*N = 0 (mod 2^BW)
3227     //
3228     // equivalent to:
3229     //
3230     //             Step*N = -Start (mod 2^BW)
3231     //
3232     // where BW is the common bit width of Start and Step.
3233
3234     // Get the initial value for the loop.
3235     SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
3236     SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
3237
3238     if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
3239       // For now we handle only constant steps.
3240
3241       // First, handle unitary steps.
3242       if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
3243         return getNegativeSCEV(Start);       //   N = -Start (as unsigned)
3244       if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
3245         return Start;                           //    N = Start (as unsigned)
3246
3247       // Then, try to solve the above equation provided that Start is constant.
3248       if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
3249         return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
3250                                             -StartC->getValue()->getValue(),
3251                                             *this);
3252     }
3253   } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
3254     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
3255     // the quadratic equation to solve it.
3256     std::pair<SCEVHandle,SCEVHandle> Roots = SolveQuadraticEquation(AddRec,
3257                                                                     *this);
3258     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
3259     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
3260     if (R1) {
3261 #if 0
3262       errs() << "HFTZ: " << *V << " - sol#1: " << *R1
3263              << "  sol#2: " << *R2 << "\n";
3264 #endif
3265       // Pick the smallest positive root value.
3266       if (ConstantInt *CB =
3267           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, 
3268                                    R1->getValue(), R2->getValue()))) {
3269         if (CB->getZExtValue() == false)
3270           std::swap(R1, R2);   // R1 is the minimum root now.
3271
3272         // We can only use this value if the chrec ends up with an exact zero
3273         // value at this index.  When solving for "X*X != 5", for example, we
3274         // should not accept a root of 2.
3275         SCEVHandle Val = AddRec->evaluateAtIteration(R1, *this);
3276         if (Val->isZero())
3277           return R1;  // We found a quadratic root!
3278       }
3279     }
3280   }
3281
3282   return UnknownValue;
3283 }
3284
3285 /// HowFarToNonZero - Return the number of times a backedge checking the
3286 /// specified value for nonzero will execute.  If not computable, return
3287 /// UnknownValue
3288 SCEVHandle ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
3289   // Loops that look like: while (X == 0) are very strange indeed.  We don't
3290   // handle them yet except for the trivial case.  This could be expanded in the
3291   // future as needed.
3292
3293   // If the value is a constant, check to see if it is known to be non-zero
3294   // already.  If so, the backedge will execute zero times.
3295   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
3296     if (!C->getValue()->isNullValue())
3297       return getIntegerSCEV(0, C->getType());
3298     return UnknownValue;  // Otherwise it will loop infinitely.
3299   }
3300
3301   // We could implement others, but I really doubt anyone writes loops like
3302   // this, and if they did, they would already be constant folded.
3303   return UnknownValue;
3304 }
3305
3306 /// getLoopPredecessor - If the given loop's header has exactly one unique
3307 /// predecessor outside the loop, return it. Otherwise return null.
3308 ///
3309 BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) {
3310   BasicBlock *Header = L->getHeader();
3311   BasicBlock *Pred = 0;
3312   for (pred_iterator PI = pred_begin(Header), E = pred_end(Header);
3313        PI != E; ++PI)
3314     if (!L->contains(*PI)) {
3315       if (Pred && Pred != *PI) return 0; // Multiple predecessors.
3316       Pred = *PI;
3317     }
3318   return Pred;
3319 }
3320
3321 /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
3322 /// (which may not be an immediate predecessor) which has exactly one
3323 /// successor from which BB is reachable, or null if no such block is
3324 /// found.
3325 ///
3326 BasicBlock *
3327 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
3328   // If the block has a unique predecessor, then there is no path from the
3329   // predecessor to the block that does not go through the direct edge
3330   // from the predecessor to the block.
3331   if (BasicBlock *Pred = BB->getSinglePredecessor())
3332     return Pred;
3333
3334   // A loop's header is defined to be a block that dominates the loop.
3335   // If the header has a unique predecessor outside the loop, it must be
3336   // a block that has exactly one successor that can reach the loop.
3337   if (Loop *L = LI->getLoopFor(BB))
3338     return getLoopPredecessor(L);
3339
3340   return 0;
3341 }
3342
3343 /// isLoopGuardedByCond - Test whether entry to the loop is protected by
3344 /// a conditional between LHS and RHS.  This is used to help avoid max
3345 /// expressions in loop trip counts.
3346 bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
3347                                           ICmpInst::Predicate Pred,
3348                                           const SCEV *LHS, const SCEV *RHS) {
3349   // Interpret a null as meaning no loop, where there is obviously no guard
3350   // (interprocedural conditions notwithstanding).
3351   if (!L) return false;
3352
3353   BasicBlock *Predecessor = getLoopPredecessor(L);
3354   BasicBlock *PredecessorDest = L->getHeader();
3355
3356   // Starting at the loop predecessor, climb up the predecessor chain, as long
3357   // as there are predecessors that can be found that have unique successors
3358   // leading to the original header.
3359   for (; Predecessor;
3360        PredecessorDest = Predecessor,
3361        Predecessor = getPredecessorWithUniqueSuccessorForBB(Predecessor)) {
3362
3363     BranchInst *LoopEntryPredicate =
3364       dyn_cast<BranchInst>(Predecessor->getTerminator());
3365     if (!LoopEntryPredicate ||
3366         LoopEntryPredicate->isUnconditional())
3367       continue;
3368
3369     ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
3370     if (!ICI) continue;
3371
3372     // Now that we found a conditional branch that dominates the loop, check to
3373     // see if it is the comparison we are looking for.
3374     Value *PreCondLHS = ICI->getOperand(0);
3375     Value *PreCondRHS = ICI->getOperand(1);
3376     ICmpInst::Predicate Cond;
3377     if (LoopEntryPredicate->getSuccessor(0) == PredecessorDest)
3378       Cond = ICI->getPredicate();
3379     else
3380       Cond = ICI->getInversePredicate();
3381
3382     if (Cond == Pred)
3383       ; // An exact match.
3384     else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
3385       ; // The actual condition is beyond sufficient.
3386     else
3387       // Check a few special cases.
3388       switch (Cond) {
3389       case ICmpInst::ICMP_UGT:
3390         if (Pred == ICmpInst::ICMP_ULT) {
3391           std::swap(PreCondLHS, PreCondRHS);
3392           Cond = ICmpInst::ICMP_ULT;
3393           break;
3394         }
3395         continue;
3396       case ICmpInst::ICMP_SGT:
3397         if (Pred == ICmpInst::ICMP_SLT) {
3398           std::swap(PreCondLHS, PreCondRHS);
3399           Cond = ICmpInst::ICMP_SLT;
3400           break;
3401         }
3402         continue;
3403       case ICmpInst::ICMP_NE:
3404         // Expressions like (x >u 0) are often canonicalized to (x != 0),
3405         // so check for this case by checking if the NE is comparing against
3406         // a minimum or maximum constant.
3407         if (!ICmpInst::isTrueWhenEqual(Pred))
3408           if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
3409             const APInt &A = CI->getValue();
3410             switch (Pred) {
3411             case ICmpInst::ICMP_SLT:
3412               if (A.isMaxSignedValue()) break;
3413               continue;
3414             case ICmpInst::ICMP_SGT:
3415               if (A.isMinSignedValue()) break;
3416               continue;
3417             case ICmpInst::ICMP_ULT:
3418               if (A.isMaxValue()) break;
3419               continue;
3420             case ICmpInst::ICMP_UGT:
3421               if (A.isMinValue()) break;
3422               continue;
3423             default:
3424               continue;
3425             }
3426             Cond = ICmpInst::ICMP_NE;
3427             // NE is symmetric but the original comparison may not be. Swap
3428             // the operands if necessary so that they match below.
3429             if (isa<SCEVConstant>(LHS))
3430               std::swap(PreCondLHS, PreCondRHS);
3431             break;
3432           }
3433         continue;
3434       default:
3435         // We weren't able to reconcile the condition.
3436         continue;
3437       }
3438
3439     if (!PreCondLHS->getType()->isInteger()) continue;
3440
3441     SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS);
3442     SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS);
3443     if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) ||
3444         (LHS == getNotSCEV(PreCondRHSSCEV) &&
3445          RHS == getNotSCEV(PreCondLHSSCEV)))
3446       return true;
3447   }
3448
3449   return false;
3450 }
3451
3452 /// HowManyLessThans - Return the number of times a backedge containing the
3453 /// specified less-than comparison will execute.  If not computable, return
3454 /// UnknownValue.
3455 ScalarEvolution::BackedgeTakenInfo ScalarEvolution::
3456 HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
3457                  const Loop *L, bool isSigned) {
3458   // Only handle:  "ADDREC < LoopInvariant".
3459   if (!RHS->isLoopInvariant(L)) return UnknownValue;
3460
3461   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
3462   if (!AddRec || AddRec->getLoop() != L)
3463     return UnknownValue;
3464
3465   if (AddRec->isAffine()) {
3466     // FORNOW: We only support unit strides.
3467     unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
3468     SCEVHandle Step = AddRec->getStepRecurrence(*this);
3469     SCEVHandle NegOne = getIntegerSCEV(-1, AddRec->getType());
3470
3471     // TODO: handle non-constant strides.
3472     const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step);
3473     if (!CStep || CStep->isZero())
3474       return UnknownValue;
3475     if (CStep->isOne()) {
3476       // With unit stride, the iteration never steps past the limit value.
3477     } else if (CStep->getValue()->getValue().isStrictlyPositive()) {
3478       if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
3479         // Test whether a positive iteration iteration can step past the limit
3480         // value and past the maximum value for its type in a single step.
3481         if (isSigned) {
3482           APInt Max = APInt::getSignedMaxValue(BitWidth);
3483           if ((Max - CStep->getValue()->getValue())
3484                 .slt(CLimit->getValue()->getValue()))
3485             return UnknownValue;
3486         } else {
3487           APInt Max = APInt::getMaxValue(BitWidth);
3488           if ((Max - CStep->getValue()->getValue())
3489                 .ult(CLimit->getValue()->getValue()))
3490             return UnknownValue;
3491         }
3492       } else
3493         // TODO: handle non-constant limit values below.
3494         return UnknownValue;
3495     } else
3496       // TODO: handle negative strides below.
3497       return UnknownValue;
3498
3499     // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
3500     // m.  So, we count the number of iterations in which {n,+,s} < m is true.
3501     // Note that we cannot simply return max(m-n,0)/s because it's not safe to
3502     // treat m-n as signed nor unsigned due to overflow possibility.
3503
3504     // First, we get the value of the LHS in the first iteration: n
3505     SCEVHandle Start = AddRec->getOperand(0);
3506
3507     // Determine the minimum constant start value.
3508     SCEVHandle MinStart = isa<SCEVConstant>(Start) ? Start :
3509       getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) :
3510                              APInt::getMinValue(BitWidth));
3511
3512     // If we know that the condition is true in order to enter the loop,
3513     // then we know that it will run exactly (m-n)/s times. Otherwise, we
3514     // only know that it will execute (max(m,n)-n)/s times. In both cases,
3515     // the division must round up.
3516     SCEVHandle End = RHS;
3517     if (!isLoopGuardedByCond(L,
3518                              isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
3519                              getMinusSCEV(Start, Step), RHS))
3520       End = isSigned ? getSMaxExpr(RHS, Start)
3521                      : getUMaxExpr(RHS, Start);
3522
3523     // Determine the maximum constant end value.
3524     SCEVHandle MaxEnd = isa<SCEVConstant>(End) ? End :
3525       getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) :
3526                              APInt::getMaxValue(BitWidth));
3527
3528     // Finally, we subtract these two values and divide, rounding up, to get
3529     // the number of times the backedge is executed.
3530     SCEVHandle BECount = getUDivExpr(getAddExpr(getMinusSCEV(End, Start),
3531                                                 getAddExpr(Step, NegOne)),
3532                                      Step);
3533
3534     // The maximum backedge count is similar, except using the minimum start
3535     // value and the maximum end value.
3536     SCEVHandle MaxBECount = getUDivExpr(getAddExpr(getMinusSCEV(MaxEnd,
3537                                                                 MinStart),
3538                                                    getAddExpr(Step, NegOne)),
3539                                         Step);
3540
3541     return BackedgeTakenInfo(BECount, MaxBECount);
3542   }
3543
3544   return UnknownValue;
3545 }
3546
3547 /// getNumIterationsInRange - Return the number of iterations of this loop that
3548 /// produce values in the specified constant range.  Another way of looking at
3549 /// this is that it returns the first iteration number where the value is not in
3550 /// the condition, thus computing the exit count. If the iteration count can't
3551 /// be computed, an instance of SCEVCouldNotCompute is returned.
3552 SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
3553                                                    ScalarEvolution &SE) const {
3554   if (Range.isFullSet())  // Infinite loop.
3555     return SE.getCouldNotCompute();
3556
3557   // If the start is a non-zero constant, shift the range to simplify things.
3558   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
3559     if (!SC->getValue()->isZero()) {
3560       std::vector<SCEVHandle> Operands(op_begin(), op_end());
3561       Operands[0] = SE.getIntegerSCEV(0, SC->getType());
3562       SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop());
3563       if (const SCEVAddRecExpr *ShiftedAddRec =
3564             dyn_cast<SCEVAddRecExpr>(Shifted))
3565         return ShiftedAddRec->getNumIterationsInRange(
3566                            Range.subtract(SC->getValue()->getValue()), SE);
3567       // This is strange and shouldn't happen.
3568       return SE.getCouldNotCompute();
3569     }
3570
3571   // The only time we can solve this is when we have all constant indices.
3572   // Otherwise, we cannot determine the overflow conditions.
3573   for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
3574     if (!isa<SCEVConstant>(getOperand(i)))
3575       return SE.getCouldNotCompute();
3576
3577
3578   // Okay at this point we know that all elements of the chrec are constants and
3579   // that the start element is zero.
3580
3581   // First check to see if the range contains zero.  If not, the first
3582   // iteration exits.
3583   unsigned BitWidth = SE.getTypeSizeInBits(getType());
3584   if (!Range.contains(APInt(BitWidth, 0)))
3585     return SE.getConstant(ConstantInt::get(getType(),0));
3586
3587   if (isAffine()) {
3588     // If this is an affine expression then we have this situation:
3589     //   Solve {0,+,A} in Range  ===  Ax in Range
3590
3591     // We know that zero is in the range.  If A is positive then we know that
3592     // the upper value of the range must be the first possible exit value.
3593     // If A is negative then the lower of the range is the last possible loop
3594     // value.  Also note that we already checked for a full range.
3595     APInt One(BitWidth,1);
3596     APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
3597     APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
3598
3599     // The exit value should be (End+A)/A.
3600     APInt ExitVal = (End + A).udiv(A);
3601     ConstantInt *ExitValue = ConstantInt::get(ExitVal);
3602
3603     // Evaluate at the exit value.  If we really did fall out of the valid
3604     // range, then we computed our trip count, otherwise wrap around or other
3605     // things must have happened.
3606     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
3607     if (Range.contains(Val->getValue()))
3608       return SE.getCouldNotCompute();  // Something strange happened
3609
3610     // Ensure that the previous value is in the range.  This is a sanity check.
3611     assert(Range.contains(
3612            EvaluateConstantChrecAtConstant(this, 
3613            ConstantInt::get(ExitVal - One), SE)->getValue()) &&
3614            "Linear scev computation is off in a bad way!");
3615     return SE.getConstant(ExitValue);
3616   } else if (isQuadratic()) {
3617     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
3618     // quadratic equation to solve it.  To do this, we must frame our problem in
3619     // terms of figuring out when zero is crossed, instead of when
3620     // Range.getUpper() is crossed.
3621     std::vector<SCEVHandle> NewOps(op_begin(), op_end());
3622     NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
3623     SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop());
3624
3625     // Next, solve the constructed addrec
3626     std::pair<SCEVHandle,SCEVHandle> Roots =
3627       SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
3628     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
3629     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
3630     if (R1) {
3631       // Pick the smallest positive root value.
3632       if (ConstantInt *CB =
3633           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, 
3634                                    R1->getValue(), R2->getValue()))) {
3635         if (CB->getZExtValue() == false)
3636           std::swap(R1, R2);   // R1 is the minimum root now.
3637
3638         // Make sure the root is not off by one.  The returned iteration should
3639         // not be in the range, but the previous one should be.  When solving
3640         // for "X*X < 5", for example, we should not return a root of 2.
3641         ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
3642                                                              R1->getValue(),
3643                                                              SE);
3644         if (Range.contains(R1Val->getValue())) {
3645           // The next iteration must be out of the range...
3646           ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
3647
3648           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3649           if (!Range.contains(R1Val->getValue()))
3650             return SE.getConstant(NextVal);
3651           return SE.getCouldNotCompute();  // Something strange happened
3652         }
3653
3654         // If R1 was not in the range, then it is a good return value.  Make
3655         // sure that R1-1 WAS in the range though, just in case.
3656         ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
3657         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
3658         if (Range.contains(R1Val->getValue()))
3659           return R1;
3660         return SE.getCouldNotCompute();  // Something strange happened
3661       }
3662     }
3663   }
3664
3665   return SE.getCouldNotCompute();
3666 }
3667
3668
3669
3670 //===----------------------------------------------------------------------===//
3671 //                   SCEVCallbackVH Class Implementation
3672 //===----------------------------------------------------------------------===//
3673
3674 void ScalarEvolution::SCEVCallbackVH::deleted() {
3675   assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!");
3676   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
3677     SE->ConstantEvolutionLoopExitValue.erase(PN);
3678   if (Instruction *I = dyn_cast<Instruction>(getValPtr()))
3679     SE->ValuesAtScopes.erase(I);
3680   SE->Scalars.erase(getValPtr());
3681   // this now dangles!
3682 }
3683
3684 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
3685   assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!");
3686
3687   // Forget all the expressions associated with users of the old value,
3688   // so that future queries will recompute the expressions using the new
3689   // value.
3690   SmallVector<User *, 16> Worklist;
3691   Value *Old = getValPtr();
3692   bool DeleteOld = false;
3693   for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
3694        UI != UE; ++UI)
3695     Worklist.push_back(*UI);
3696   while (!Worklist.empty()) {
3697     User *U = Worklist.pop_back_val();
3698     // Deleting the Old value will cause this to dangle. Postpone
3699     // that until everything else is done.
3700     if (U == Old) {
3701       DeleteOld = true;
3702       continue;
3703     }
3704     if (PHINode *PN = dyn_cast<PHINode>(U))
3705       SE->ConstantEvolutionLoopExitValue.erase(PN);
3706     if (Instruction *I = dyn_cast<Instruction>(U))
3707       SE->ValuesAtScopes.erase(I);
3708     if (SE->Scalars.erase(U))
3709       for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
3710            UI != UE; ++UI)
3711         Worklist.push_back(*UI);
3712   }
3713   if (DeleteOld) {
3714     if (PHINode *PN = dyn_cast<PHINode>(Old))
3715       SE->ConstantEvolutionLoopExitValue.erase(PN);
3716     if (Instruction *I = dyn_cast<Instruction>(Old))
3717       SE->ValuesAtScopes.erase(I);
3718     SE->Scalars.erase(Old);
3719     // this now dangles!
3720   }
3721   // this may dangle!
3722 }
3723
3724 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
3725   : CallbackVH(V), SE(se) {}
3726
3727 //===----------------------------------------------------------------------===//
3728 //                   ScalarEvolution Class Implementation
3729 //===----------------------------------------------------------------------===//
3730
3731 ScalarEvolution::ScalarEvolution()
3732   : FunctionPass(&ID), UnknownValue(new SCEVCouldNotCompute()) {
3733 }
3734
3735 bool ScalarEvolution::runOnFunction(Function &F) {
3736   this->F = &F;
3737   LI = &getAnalysis<LoopInfo>();
3738   TD = getAnalysisIfAvailable<TargetData>();
3739   return false;
3740 }
3741
3742 void ScalarEvolution::releaseMemory() {
3743   Scalars.clear();
3744   BackedgeTakenCounts.clear();
3745   ConstantEvolutionLoopExitValue.clear();
3746   ValuesAtScopes.clear();
3747 }
3748
3749 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
3750   AU.setPreservesAll();
3751   AU.addRequiredTransitive<LoopInfo>();
3752 }
3753
3754 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
3755   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
3756 }
3757
3758 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
3759                           const Loop *L) {
3760   // Print all inner loops first
3761   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
3762     PrintLoopInfo(OS, SE, *I);
3763
3764   OS << "Loop " << L->getHeader()->getName() << ": ";
3765
3766   SmallVector<BasicBlock*, 8> ExitBlocks;
3767   L->getExitBlocks(ExitBlocks);
3768   if (ExitBlocks.size() != 1)
3769     OS << "<multiple exits> ";
3770
3771   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
3772     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
3773   } else {
3774     OS << "Unpredictable backedge-taken count. ";
3775   }
3776
3777   OS << "\n";
3778 }
3779
3780 void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {
3781   // ScalarEvolution's implementaiton of the print method is to print
3782   // out SCEV values of all instructions that are interesting. Doing
3783   // this potentially causes it to create new SCEV objects though,
3784   // which technically conflicts with the const qualifier. This isn't
3785   // observable from outside the class though (the hasSCEV function
3786   // notwithstanding), so casting away the const isn't dangerous.
3787   ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this);
3788
3789   OS << "Classifying expressions for: " << F->getName() << "\n";
3790   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
3791     if (isSCEVable(I->getType())) {
3792       OS << *I;
3793       OS << "  -->  ";
3794       SCEVHandle SV = SE.getSCEV(&*I);
3795       SV->print(OS);
3796       OS << "\t\t";
3797
3798       if (const Loop *L = LI->getLoopFor((*I).getParent())) {
3799         OS << "Exits: ";
3800         SCEVHandle ExitValue = SE.getSCEVAtScope(&*I, L->getParentLoop());
3801         if (!ExitValue->isLoopInvariant(L)) {
3802           OS << "<<Unknown>>";
3803         } else {
3804           OS << *ExitValue;
3805         }
3806       }
3807
3808       OS << "\n";
3809     }
3810
3811   OS << "Determining loop execution counts for: " << F->getName() << "\n";
3812   for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
3813     PrintLoopInfo(OS, &SE, *I);
3814 }
3815
3816 void ScalarEvolution::print(std::ostream &o, const Module *M) const {
3817   raw_os_ostream OS(o);
3818   print(OS, M);
3819 }