Adding dllimport, dllexport and external weak linkage types.
[oota-llvm.git] / lib / Transforms / Scalar / PredicateSimplifier.cpp
1 //===-- PredicateSimplifier.cpp - Path Sensitive Simplifier -----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Nick Lewycky and is distributed under the
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===------------------------------------------------------------------===//
9 //
10 // Path-sensitive optimizer. In a branch where x == y, replace uses of
11 // x with y. Permits further optimization, such as the elimination of
12 // the unreachable call:
13 //
14 // void test(int *p, int *q)
15 // {
16 //   if (p != q)
17 //     return;
18 // 
19 //   if (*p != *q)
20 //     foo(); // unreachable
21 // }
22 //
23 //===------------------------------------------------------------------===//
24 //
25 // This optimization works by substituting %q for %p when protected by a
26 // conditional that assures us of that fact. Properties are stored as
27 // relationships between two values.
28 //
29 //===------------------------------------------------------------------===//
30
31 #define DEBUG_TYPE "predsimplify"
32 #include "llvm/Transforms/Scalar.h"
33 #include "llvm/Constants.h"
34 #include "llvm/Instructions.h"
35 #include "llvm/Pass.h"
36 #include "llvm/ADT/Statistic.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/Analysis/Dominators.h"
39 #include "llvm/Support/CFG.h"
40 #include "llvm/Support/Debug.h"
41 #include <iostream>
42 using namespace llvm;
43
44 typedef DominatorTree::Node DTNodeType;
45
46 namespace {
47   Statistic<>
48   NumVarsReplaced("predsimplify", "Number of argument substitutions");
49   Statistic<>
50   NumInstruction("predsimplify", "Number of instructions removed");
51   Statistic<>
52   NumSwitchCases("predsimplify", "Number of switch cases removed");
53   Statistic<>
54   NumBranches("predsimplify", "Number of branches made unconditional");
55
56   /// Returns true if V1 is a better choice than V2. Note that it is
57   /// not a total ordering.
58   struct compare {
59     bool operator()(Value *V1, Value *V2) const {
60       if (isa<Constant>(V1)) {
61         if (!isa<Constant>(V2)) {
62           return true;
63         }
64       } else if (isa<Argument>(V1)) {
65         if (!isa<Constant>(V2) && !isa<Argument>(V2)) {
66           return true;
67         }
68       }
69       if (User *U = dyn_cast<User>(V2)) {
70         for (User::const_op_iterator I = U->op_begin(), E = U->op_end();
71              I != E; ++I) {
72           if (*I == V1) {
73             return true;
74           }
75         }
76       }
77       return false;
78     }
79   };
80
81   /// Used for choosing the canonical Value in a synonym set.
82   /// Leaves the better choice in V1.
83   static void order(Value *&V1, Value *&V2) {
84     static compare c;
85     if (c(V2, V1))
86       std::swap(V1, V2);
87   }
88
89   /// Similar to EquivalenceClasses, this stores the set of equivalent
90   /// types. Beyond EquivalenceClasses, it allows the user to specify
91   /// which element will act as leader through a StrictWeakOrdering
92   /// function.
93   template<typename ElemTy, typename StrictWeak>
94   class VISIBILITY_HIDDEN Synonyms {
95     std::map<ElemTy, unsigned> mapping;
96     std::vector<ElemTy> leaders;
97     StrictWeak swo;
98
99   public:
100     typedef unsigned iterator;
101     typedef const unsigned const_iterator;
102
103     // Inspection
104
105     bool empty() const {
106       return leaders.empty();
107     }
108
109     unsigned countLeaders() const {
110       return leaders.size();
111     }
112
113     iterator findLeader(ElemTy e) {
114       typename std::map<ElemTy, unsigned>::iterator MI = mapping.find(e);
115       if (MI == mapping.end()) return 0;
116
117       return MI->second;
118     }
119
120     const_iterator findLeader(ElemTy e) const {
121       typename std::map<ElemTy, unsigned>::const_iterator MI =
122           mapping.find(e);
123       if (MI == mapping.end()) return 0;
124
125       return MI->second;
126     }
127
128     ElemTy &getLeader(iterator I) {
129       assert(I != 0 && "Element zero is out of range.");
130       assert(I <= leaders.size() && "Invalid iterator.");
131       return leaders[I-1];
132     }
133
134     const ElemTy &getLeader(const_iterator I) const {
135       assert(I != 0 && "Element zero is out of range.");
136       return leaders[I-1];
137     }
138
139 #ifdef DEBUG
140     void debug(std::ostream &os) const {
141       std::set<Value *> Unique;
142       for (unsigned i = 1, e = leaders.size()+1; i != e; ++i) {
143         Unique.insert(getLeader(i));
144         os << i << ". " << *getLeader(i) << ": [";
145         for (std::map<Value *, unsigned>::const_iterator
146              I = mapping.begin(), E = mapping.end(); I != E; ++I) {
147           if ((*I).second == i && (*I).first != leaders[i-1]) {
148             os << *(*I).first << "  ";
149           }
150         }
151         os << "]\n";
152       }
153       assert(Unique.size() == leaders.size() && "Duplicate leaders.");
154
155       for (typename std::map<ElemTy, unsigned>::const_iterator
156            I = mapping.begin(), E = mapping.end(); I != E; ++I) {
157         assert(I->second != 0 && "Zero iterator in mapping.");
158         assert(I->second <= leaders.size() &&
159                "Invalid iterator found in mapping.");
160       }
161     }
162 #endif
163
164     // Mutators
165
166     /// Combine two sets referring to the same element, inserting the
167     /// elements as needed. Returns a valid iterator iff two already
168     /// existing disjoint synonym sets were combined. The iterator
169     /// points to the removed element.
170     iterator unionSets(ElemTy E1, ElemTy E2) {
171       if (swo(E2, E1)) std::swap(E1, E2);
172
173       iterator I1 = findLeader(E1),
174                I2 = findLeader(E2);
175
176       if (!I1 && !I2) { // neither entry is in yet
177         leaders.push_back(E1);
178         I1 = leaders.size();
179         mapping[E1] = I1;
180         mapping[E2] = I1;
181         return 0;
182       }
183
184       if (!I1 && I2) {
185         mapping[E1] = I2;
186         std::swap(getLeader(I2), E1);
187         return 0;
188       }
189
190       if (I1 && !I2) {
191         mapping[E2] = I1;
192         return 0;
193       }
194
195       if (I1 == I2) return 0;
196
197       // This is the case where we have two sets, [%a1, %a2, %a3] and
198       // [%p1, %p2, %p3] and someone says that %a2 == %p3. We need to
199       // combine the two synsets.
200
201       if (I1 > I2) --I1;
202
203       for (std::map<Value *, unsigned>::iterator I = mapping.begin(),
204            E = mapping.end(); I != E; ++I) {
205         if (I->second == I2) I->second = I1;
206         else if (I->second > I2) --I->second;
207       }
208
209       leaders.erase(leaders.begin() + I2 - 1);
210
211       return I2;
212     }
213
214     /// Returns an iterator pointing to the synonym set containing
215     /// element e. If none exists, a new one is created and returned.
216     iterator findOrInsert(ElemTy e) {
217       iterator I = findLeader(e);
218       if (I) return I;
219
220       leaders.push_back(e);
221       I = leaders.size();
222       mapping[e] = I;
223       return I;
224     }
225   };
226
227   /// Represents the set of equivalent Value*s and provides insertion
228   /// and fast lookup. Also stores the set of inequality relationships.
229   class PropertySet {
230     struct Property;
231   public:
232     class Synonyms<Value *, compare> union_find;
233
234     typedef std::vector<Property>::iterator       PropertyIterator;
235     typedef std::vector<Property>::const_iterator ConstPropertyIterator;
236     typedef Synonyms<Value *, compare>::iterator  SynonymIterator;
237
238     enum Ops {
239       EQ,
240       NE
241     };
242
243     Value *canonicalize(Value *V) const {
244       Value *C = lookup(V);
245       return C ? C : V;
246     }
247
248     Value *lookup(Value *V) const {
249       Synonyms<Value *, compare>::iterator SI = union_find.findLeader(V);
250       if (!SI) return NULL;
251       return union_find.getLeader(SI);
252     }
253
254     bool empty() const {
255       return union_find.empty();
256     }
257
258     void addEqual(Value *V1, Value *V2) {
259       // If %x = 0. and %y = -0., seteq %x, %y is true, but
260       // copysign(%x) is not the same as copysign(%y).
261       if (V1->getType()->isFloatingPoint()) return;
262
263       order(V1, V2);
264       if (isa<Constant>(V2)) return; // refuse to set false == true.
265
266       DEBUG(std::cerr << "equal: " << *V1 << " and " << *V2 << "\n");
267       SynonymIterator deleted = union_find.unionSets(V1, V2);
268       if (deleted) {
269         SynonymIterator replacement = union_find.findLeader(V1);
270         // Move Properties
271         for (PropertyIterator I = Properties.begin(), E = Properties.end();
272              I != E; ++I) {
273           if (I->I1 == deleted) I->I1 = replacement;
274           else if (I->I1 > deleted) --I->I1;
275           if (I->I2 == deleted) I->I2 = replacement;
276           else if (I->I2 > deleted) --I->I2;
277         }
278       }
279       addImpliedProperties(EQ, V1, V2);
280     }
281
282     void addNotEqual(Value *V1, Value *V2) {
283       // If %x = NAN then seteq %x, %x is false.
284       if (V1->getType()->isFloatingPoint()) return;
285
286       // For example, %x = setne int 0, 0 causes "0 != 0".
287       if (isa<Constant>(V1) && isa<Constant>(V2)) return;
288
289       DEBUG(std::cerr << "not equal: " << *V1 << " and " << *V2 << "\n");
290       if (findProperty(NE, V1, V2) != Properties.end())
291         return; // found.
292
293       // Add the property.
294       SynonymIterator I1 = union_find.findOrInsert(V1),
295                       I2 = union_find.findOrInsert(V2);
296
297       // Technically this means that the block is unreachable.
298       if (I1 == I2) return;
299
300       Properties.push_back(Property(NE, I1, I2));
301       addImpliedProperties(NE, V1, V2);
302     }
303
304     PropertyIterator findProperty(Ops Opcode, Value *V1, Value *V2) {
305       assert(Opcode != EQ && "Can't findProperty on EQ."
306              "Use the lookup method instead.");
307
308       SynonymIterator I1 = union_find.findLeader(V1),
309                       I2 = union_find.findLeader(V2);
310       if (!I1 || !I2) return Properties.end();
311
312       return
313       find(Properties.begin(), Properties.end(), Property(Opcode, I1, I2));
314     }
315
316     ConstPropertyIterator
317     findProperty(Ops Opcode, Value *V1, Value *V2) const {
318       assert(Opcode != EQ && "Can't findProperty on EQ."
319              "Use the lookup method instead.");
320
321       SynonymIterator I1 = union_find.findLeader(V1),
322                       I2 = union_find.findLeader(V2);
323       if (!I1 || !I2) return Properties.end();
324
325       return
326       find(Properties.begin(), Properties.end(), Property(Opcode, I1, I2));
327     }
328
329   private:
330     // Represents Head OP [Tail1, Tail2, ...]
331     // For example: %x != %a, %x != %b.
332     struct VISIBILITY_HIDDEN Property {
333       typedef Synonyms<Value *, compare>::iterator Iter;
334
335       Property(Ops opcode, Iter i1, Iter i2)
336         : Opcode(opcode), I1(i1), I2(i2)
337       { assert(opcode != EQ && "Equality belongs in the synonym set, "
338                                "not a property."); }
339
340       bool operator==(const Property &P) const {
341         return (Opcode == P.Opcode) &&
342                ((I1 == P.I1 && I2 == P.I2) ||
343                 (I1 == P.I2 && I2 == P.I1));
344       }
345
346       Ops Opcode;
347       Iter I1, I2;
348     };
349
350     void add(Ops Opcode, Value *V1, Value *V2, bool invert) {
351       switch (Opcode) {
352         case EQ:
353           if (invert) addNotEqual(V1, V2);
354           else        addEqual(V1, V2);
355           break;
356         case NE:
357           if (invert) addEqual(V1, V2);
358           else        addNotEqual(V1, V2);
359           break;
360         default:
361           assert(0 && "Unknown property opcode.");
362       }
363     }
364
365     // Finds the properties implied by an equivalence and adds them too.
366     // Example: ("seteq %a, %b", true,  EQ) --> (%a, %b, EQ)
367     //          ("seteq %a, %b", false, EQ) --> (%a, %b, NE)
368     void addImpliedProperties(Ops Opcode, Value *V1, Value *V2) {
369       order(V1, V2);
370
371       if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V2)) {
372         switch (BO->getOpcode()) {
373         case Instruction::SetEQ:
374           if (V1 == ConstantBool::True)
375             add(Opcode, BO->getOperand(0), BO->getOperand(1), false);
376           if (V1 == ConstantBool::False)
377             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
378           break;
379         case Instruction::SetNE:
380           if (V1 == ConstantBool::True)
381             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
382           if (V1 == ConstantBool::False)
383             add(Opcode, BO->getOperand(0), BO->getOperand(1), false);
384           break;
385         case Instruction::SetLT:
386         case Instruction::SetGT:
387           if (V1 == ConstantBool::True)
388             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
389           break;
390         case Instruction::SetLE:
391         case Instruction::SetGE:
392           if (V1 == ConstantBool::False)
393             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
394           break;
395         case Instruction::And:
396           if (V1 == ConstantBool::True) {
397             add(Opcode, ConstantBool::True, BO->getOperand(0), false);
398             add(Opcode, ConstantBool::True, BO->getOperand(1), false);
399           }
400           break;
401         case Instruction::Or:
402           if (V1 == ConstantBool::False) {
403             add(Opcode, ConstantBool::False, BO->getOperand(0), false);
404             add(Opcode, ConstantBool::False, BO->getOperand(1), false);
405           }
406           break;
407         case Instruction::Xor:
408           if (V1 == ConstantBool::True) {
409             if (BO->getOperand(0) == ConstantBool::True)
410               add(Opcode, ConstantBool::False, BO->getOperand(1), false);
411             if (BO->getOperand(1) == ConstantBool::True)
412               add(Opcode, ConstantBool::False, BO->getOperand(0), false);
413           }
414           if (V1 == ConstantBool::False) {
415             if (BO->getOperand(0) == ConstantBool::True)
416               add(Opcode, ConstantBool::True, BO->getOperand(1), false);
417             if (BO->getOperand(1) == ConstantBool::True)
418               add(Opcode, ConstantBool::True, BO->getOperand(0), false);
419           }
420           break;
421         default:
422           break;
423         }
424       } else if (SelectInst *SI = dyn_cast<SelectInst>(V2)) {
425         if (Opcode != EQ && Opcode != NE) return;
426
427         ConstantBool *True  = (Opcode==EQ) ? ConstantBool::True
428                                            : ConstantBool::False,
429                      *False = (Opcode==EQ) ? ConstantBool::False
430                                            : ConstantBool::True;
431
432         if (V1 == SI->getTrueValue())
433           addEqual(SI->getCondition(), True);
434         else if (V1 == SI->getFalseValue())
435           addEqual(SI->getCondition(), False);
436         else if (Opcode == EQ)
437           assert("Result of select not equal to either value.");
438       }
439     }
440
441   public:
442 #ifdef DEBUG
443     void debug(std::ostream &os) const {
444       static const char *OpcodeTable[] = { "EQ", "NE" };
445
446       unsigned int size = union_find.countLeaders();
447
448       union_find.debug(os);
449       for (std::vector<Property>::const_iterator I = Properties.begin(),
450            E = Properties.end(); I != E; ++I) {
451         os << (*I).I1 << " " << OpcodeTable[(*I).Opcode] << " "
452            << (*I).I2 << "\n";
453         assert((*I).I1 <= size && "Invalid property.");
454         assert((*I).I2 <= size && "Invalid property.");
455       }
456       os << "\n";
457     }
458 #endif
459
460     std::vector<Property> Properties;
461   };
462
463   /// PredicateSimplifier - This class is a simplifier that replaces
464   /// one equivalent variable with another. It also tracks what
465   /// can't be equal and will solve setcc instructions when possible.
466   class PredicateSimplifier : public FunctionPass {
467   public:
468     bool runOnFunction(Function &F);
469     virtual void getAnalysisUsage(AnalysisUsage &AU) const;
470
471   private:
472     // Try to replace the Use of the instruction with something simpler.
473     Value *resolve(SetCondInst *SCI, const PropertySet &);
474     Value *resolve(BinaryOperator *BO, const PropertySet &);
475     Value *resolve(SelectInst *SI, const PropertySet &);
476     Value *resolve(Value *V, const PropertySet &);
477
478     // Used by terminator instructions to proceed from the current basic
479     // block to the next. Verifies that "current" dominates "next",
480     // then calls visitBasicBlock.
481     void proceedToSuccessor(PropertySet &CurrentPS, PropertySet &NextPS,
482                             DTNodeType *Current, DTNodeType *Next);
483     void proceedToSuccessor(PropertySet &CurrentPS,
484                             DTNodeType *Current, DTNodeType *Next);
485
486     // Visits each instruction in the basic block.
487     void visitBasicBlock(DTNodeType *DTNode, PropertySet &KnownProperties);
488
489     // Tries to simplify each Instruction and add new properties to
490     // the PropertySet. Returns true if it erase the instruction.
491     void visitInstruction(Instruction *I, DTNodeType *, PropertySet &);
492     // For each instruction, add the properties to KnownProperties.
493
494     void visit(TerminatorInst *TI, DTNodeType *, PropertySet &);
495     void visit(BranchInst *BI, DTNodeType *, PropertySet &);
496     void visit(SwitchInst *SI, DTNodeType *, PropertySet);
497     void visit(LoadInst *LI, DTNodeType *, PropertySet &);
498     void visit(StoreInst *SI, DTNodeType *, PropertySet &);
499     void visit(BinaryOperator *BO, DTNodeType *, PropertySet &);
500
501     DominatorTree *DT;
502     bool modified;
503   };
504
505   RegisterPass<PredicateSimplifier> X("predsimplify",
506                                       "Predicate Simplifier");
507 }
508
509 FunctionPass *llvm::createPredicateSimplifierPass() {
510   return new PredicateSimplifier();
511 }
512
513 bool PredicateSimplifier::runOnFunction(Function &F) {
514   DT = &getAnalysis<DominatorTree>();
515
516   modified = false;
517   PropertySet KnownProperties;
518   visitBasicBlock(DT->getRootNode(), KnownProperties);
519   return modified;
520 }
521
522 void PredicateSimplifier::getAnalysisUsage(AnalysisUsage &AU) const {
523   AU.addRequired<DominatorTree>();
524 }
525
526 // resolve catches cases addProperty won't because it wasn't used as a
527 // condition in the branch, and that visit won't, because the instruction
528 // was defined outside of the scope that the properties apply to.
529 Value *PredicateSimplifier::resolve(SetCondInst *SCI,
530                                     const PropertySet &KP) {
531   // Attempt to resolve the SetCondInst to a boolean.
532
533   Value *SCI0 = resolve(SCI->getOperand(0), KP),
534         *SCI1 = resolve(SCI->getOperand(1), KP);
535
536   ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(SCI0),
537                    *CI2 = dyn_cast<ConstantIntegral>(SCI1);
538
539   if (!CI1 || !CI2) {
540     PropertySet::ConstPropertyIterator NE =
541         KP.findProperty(PropertySet::NE, SCI0, SCI1);
542
543     if (NE != KP.Properties.end()) {
544       switch (SCI->getOpcode()) {
545         case Instruction::SetEQ:
546           return ConstantBool::False;
547         case Instruction::SetNE:
548           return ConstantBool::True;
549         case Instruction::SetLE:
550         case Instruction::SetGE:
551         case Instruction::SetLT:
552         case Instruction::SetGT:
553           break;
554         default:
555           assert(0 && "Unknown opcode in SetCondInst.");
556           break;
557       }
558     }
559     return SCI;
560   }
561
562   switch(SCI->getOpcode()) {
563     case Instruction::SetLE:
564     case Instruction::SetGE:
565     case Instruction::SetEQ:
566       if (CI1->getRawValue() == CI2->getRawValue())
567         return ConstantBool::True;
568       else
569         return ConstantBool::False;
570     case Instruction::SetLT:
571     case Instruction::SetGT:
572     case Instruction::SetNE:
573       if (CI1->getRawValue() == CI2->getRawValue())
574         return ConstantBool::False;
575       else
576         return ConstantBool::True;
577     default:
578       assert(0 && "Unknown opcode in SetContInst.");
579       break;
580   }
581 }
582
583 Value *PredicateSimplifier::resolve(BinaryOperator *BO,
584                                     const PropertySet &KP) {
585   if (SetCondInst *SCI = dyn_cast<SetCondInst>(BO))
586     return resolve(SCI, KP);
587
588   Value *lhs = resolve(BO->getOperand(0), KP),
589         *rhs = resolve(BO->getOperand(1), KP);
590   ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(lhs);
591   ConstantIntegral *CI2 = dyn_cast<ConstantIntegral>(rhs);
592
593   if (!CI1 || !CI2) return BO;
594
595   Value *V = ConstantExpr::get(BO->getOpcode(), CI1, CI2);
596   if (V) return V;
597   return BO;
598 }
599
600 Value *PredicateSimplifier::resolve(SelectInst *SI, const PropertySet &KP) {
601   Value *Condition = resolve(SI->getCondition(), KP);
602   if (Condition == ConstantBool::True)
603     return resolve(SI->getTrueValue(), KP);
604   else if (Condition == ConstantBool::False)
605     return resolve(SI->getFalseValue(), KP);
606   return SI;
607 }
608
609 Value *PredicateSimplifier::resolve(Value *V, const PropertySet &KP) {
610   if (isa<Constant>(V) || isa<BasicBlock>(V) || KP.empty()) return V;
611
612   V = KP.canonicalize(V);
613
614   DEBUG(std::cerr << "peering into " << *V << "\n");
615
616   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V))
617     return resolve(BO, KP);
618   else if (SelectInst *SI = dyn_cast<SelectInst>(V))
619     return resolve(SI, KP);
620
621   return V;
622 }
623
624 void PredicateSimplifier::visitBasicBlock(DTNodeType *DTNode,
625                                           PropertySet &KnownProperties) {
626   BasicBlock *BB = DTNode->getBlock();
627   for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
628     visitInstruction(I++, DTNode, KnownProperties);
629   }
630 }
631
632 void PredicateSimplifier::visitInstruction(Instruction *I,
633                                            DTNodeType *DTNode,
634                                            PropertySet &KnownProperties) {
635
636   DEBUG(std::cerr << "Considering instruction " << *I << "\n");
637   DEBUG(KnownProperties.debug(std::cerr));
638
639   // Try to replace the whole instruction.
640   Value *V = resolve(I, KnownProperties);
641   assert(V->getType() == I->getType() && "Instruction type mutated!");
642   if (V != I) {
643     modified = true;
644     ++NumInstruction;
645     I->replaceAllUsesWith(V);
646     I->eraseFromParent();
647     return;
648   }
649
650   // Try to substitute operands.
651   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
652     Value *Oper = I->getOperand(i);
653     Value *V = resolve(Oper, KnownProperties);
654     assert(V->getType() == Oper->getType() && "Operand type mutated!");
655     if (V != Oper) {
656       modified = true;
657       ++NumVarsReplaced;
658       DEBUG(std::cerr << "resolving " << *I);
659       I->setOperand(i, V);
660       DEBUG(std::cerr << "into " << *I);
661     }
662   }
663
664   if (TerminatorInst *TI = dyn_cast<TerminatorInst>(I))
665     visit(TI, DTNode, KnownProperties);
666   else if (LoadInst *LI = dyn_cast<LoadInst>(I))
667     visit(LI, DTNode, KnownProperties);
668   else if (StoreInst *SI = dyn_cast<StoreInst>(I))
669     visit(SI, DTNode, KnownProperties);
670   else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I))
671     visit(BO, DTNode, KnownProperties);
672 }
673
674 void PredicateSimplifier::proceedToSuccessor(PropertySet &CurrentPS,
675                                              PropertySet &NextPS,
676                                              DTNodeType *Current,
677                                              DTNodeType *Next) {
678   if (Next->getBlock()->getSinglePredecessor() == Current->getBlock())
679     proceedToSuccessor(NextPS, Current, Next);
680   else
681     proceedToSuccessor(CurrentPS, Current, Next);
682 }
683
684 void PredicateSimplifier::proceedToSuccessor(PropertySet &KP,
685                                              DTNodeType *Current,
686                                              DTNodeType *Next) {
687   if (Current->properlyDominates(Next))
688     visitBasicBlock(Next, KP);
689 }
690
691 void PredicateSimplifier::visit(TerminatorInst *TI, DTNodeType *Node,
692                                 PropertySet &KP) {
693   if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
694     visit(BI, Node, KP);
695     return;
696   }
697   if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
698     visit(SI, Node, KP);
699     return;
700   }
701
702   for (unsigned i = 0, E = TI->getNumSuccessors(); i != E; ++i) {
703     BasicBlock *BB = TI->getSuccessor(i);
704     PropertySet KPcopy(KP);
705     proceedToSuccessor(KPcopy, Node, DT->getNode(TI->getSuccessor(i)));
706   }
707 }
708
709 void PredicateSimplifier::visit(BranchInst *BI, DTNodeType *Node,
710                                 PropertySet &KP) {
711   if (BI->isUnconditional()) {
712     proceedToSuccessor(KP, Node, DT->getNode(BI->getSuccessor(0)));
713     return;
714   }
715
716   Value *Condition = BI->getCondition();
717
718   BasicBlock *TrueDest  = BI->getSuccessor(0),
719              *FalseDest = BI->getSuccessor(1);
720
721   if (Condition == ConstantBool::True) {
722     FalseDest->removePredecessor(BI->getParent());
723     BI->setUnconditionalDest(TrueDest);
724     modified = true;
725     ++NumBranches;
726     proceedToSuccessor(KP, Node, DT->getNode(TrueDest));
727     return;
728   } else if (Condition == ConstantBool::False) {
729     TrueDest->removePredecessor(BI->getParent());
730     BI->setUnconditionalDest(FalseDest);
731     modified = true;
732     ++NumBranches;
733     proceedToSuccessor(KP, Node, DT->getNode(FalseDest));
734     return;
735   }
736
737   PropertySet TrueProperties(KP), FalseProperties(KP);
738   DEBUG(std::cerr << "true set:\n");
739   TrueProperties.addEqual(ConstantBool::True,   Condition);
740   DEBUG(TrueProperties.debug(std::cerr));
741   DEBUG(std::cerr << "false set:\n");
742   FalseProperties.addEqual(ConstantBool::False, Condition);
743   DEBUG(FalseProperties.debug(std::cerr));
744
745   PropertySet KPcopy(KP);
746   proceedToSuccessor(KP,     TrueProperties,  Node, DT->getNode(TrueDest));
747   proceedToSuccessor(KPcopy, FalseProperties, Node, DT->getNode(FalseDest));
748 }
749
750 void PredicateSimplifier::visit(SwitchInst *SI, DTNodeType *DTNode,
751                                 PropertySet KP) {
752   Value *Condition = SI->getCondition();
753   assert(Condition == KP.canonicalize(Condition) &&
754          "Instruction wasn't already canonicalized?");
755
756   // If there's an NEProperty covering this SwitchInst, we may be able to
757   // eliminate one of the cases.
758   for (PropertySet::ConstPropertyIterator I = KP.Properties.begin(),
759        E = KP.Properties.end(); I != E; ++I) {
760     if (I->Opcode != PropertySet::NE) continue;
761     Value *V1 = KP.union_find.getLeader(I->I1),
762           *V2 = KP.union_find.getLeader(I->I2);
763
764     // Find a Property with a ConstantInt on one side and our
765     // Condition on the other.
766     ConstantInt *CI = NULL;
767     if (V1 == Condition)
768       CI = dyn_cast<ConstantInt>(V2);
769     else if (V2 == Condition)
770       CI = dyn_cast<ConstantInt>(V1);
771
772     if (!CI) continue;
773
774     unsigned i = SI->findCaseValue(CI);
775     if (i != 0) { // zero is reserved for the default case.
776       SI->getSuccessor(i)->removePredecessor(SI->getParent());
777       SI->removeCase(i);
778       modified = true;
779       ++NumSwitchCases;
780     }
781   }
782
783   // Set the EQProperty in each of the cases BBs,
784   // and the NEProperties in the default BB.
785   PropertySet DefaultProperties(KP);
786
787   DTNodeType *Node        = DT->getNode(SI->getParent()),
788              *DefaultNode = DT->getNode(SI->getSuccessor(0));
789   if (!Node->dominates(DefaultNode)) DefaultNode = NULL;
790
791   for (unsigned I = 1, E = SI->getNumCases(); I < E; ++I) {
792     ConstantInt *CI = SI->getCaseValue(I);
793
794     BasicBlock *SuccBB = SI->getSuccessor(I);
795     PropertySet copy(KP);
796     if (SuccBB->getSinglePredecessor()) {
797       PropertySet NewProperties(KP);
798       NewProperties.addEqual(Condition, CI);
799       proceedToSuccessor(copy, NewProperties, DTNode, DT->getNode(SuccBB));
800     } else
801       proceedToSuccessor(copy, DTNode, DT->getNode(SuccBB));
802
803     if (DefaultNode)
804       DefaultProperties.addNotEqual(Condition, CI);
805   }
806
807   if (DefaultNode)
808     proceedToSuccessor(DefaultProperties, DTNode, DefaultNode);
809 }
810
811 void PredicateSimplifier::visit(LoadInst *LI, DTNodeType *,
812                                 PropertySet &KP) {
813   Value *Ptr = LI->getPointerOperand();
814   KP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr);
815 }
816
817 void PredicateSimplifier::visit(StoreInst *SI, DTNodeType *,
818                                 PropertySet &KP) {
819   Value *Ptr = SI->getPointerOperand();
820   KP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr);
821 }
822
823 void PredicateSimplifier::visit(BinaryOperator *BO, DTNodeType *,
824                                 PropertySet &KP) {
825   Instruction::BinaryOps ops = BO->getOpcode();
826
827   switch (ops) {
828     case Instruction::Div:
829     case Instruction::Rem: {
830       Value *Divisor = BO->getOperand(1);
831       KP.addNotEqual(Constant::getNullValue(Divisor->getType()), Divisor);
832       break;
833     }
834     default:
835       break;
836   }
837 }