Skip the linear search if the answer is already known.
[oota-llvm.git] / lib / Transforms / Scalar / PredicateSimplifier.cpp
1 //===-- PredicateSimplifier.cpp - Path Sensitive Simplifier -----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Nick Lewycky and is distributed under the
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===------------------------------------------------------------------===//
9 //
10 // Path-sensitive optimizer. In a branch where x == y, replace uses of
11 // x with y. Permits further optimization, such as the elimination of
12 // the unreachable call:
13 //
14 // void test(int *p, int *q)
15 // {
16 //   if (p != q)
17 //     return;
18 // 
19 //   if (*p != *q)
20 //     foo(); // unreachable
21 // }
22 //
23 //===------------------------------------------------------------------===//
24 //
25 // This optimization works by substituting %q for %p when protected by a
26 // conditional that assures us of that fact. Properties are stored as
27 // relationships between two values.
28 //
29 //===------------------------------------------------------------------===//
30
31 #define DEBUG_TYPE "predsimplify"
32 #include "llvm/Transforms/Scalar.h"
33 #include "llvm/Constants.h"
34 #include "llvm/Instructions.h"
35 #include "llvm/Pass.h"
36 #include "llvm/ADT/Statistic.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/Analysis/Dominators.h"
39 #include "llvm/Support/CFG.h"
40 #include "llvm/Support/Debug.h"
41 #include <iostream>
42 using namespace llvm;
43
44 typedef DominatorTree::Node DTNodeType;
45
46 namespace {
47   Statistic<>
48   NumVarsReplaced("predsimplify", "Number of argument substitutions");
49   Statistic<>
50   NumInstruction("predsimplify", "Number of instructions removed");
51   Statistic<>
52   NumSwitchCases("predsimplify", "Number of switch cases removed");
53   Statistic<>
54   NumBranches("predsimplify", "Number of branches made unconditional");
55
56   /// Returns true if V1 is a better choice than V2. Note that it is
57   /// not a total ordering.
58   struct compare {
59     bool operator()(Value *V1, Value *V2) const {
60       if (isa<Constant>(V2)) {
61         if (!isa<Constant>(V1)) {
62           return true;
63         }
64       } else if (isa<Argument>(V2)) {
65         if (!isa<Constant>(V1) && !isa<Argument>(V1)) {
66           return true;
67         }
68       }
69       if (User *U1 = dyn_cast<User>(V1)) {
70         for (User::const_op_iterator I = U1->op_begin(), E = U1->op_end();
71              I != E; ++I) {
72           if (*I == V2) {
73             return true;
74           }
75         }
76       }
77       return false;
78     }
79   };
80
81   /// Used for choosing the canonical Value in a synonym set.
82   /// Leaves the better one in V1.
83   static void order(Value *&V1, Value *&V2) {
84     static compare c;
85     if (c(V1, V2))
86       std::swap(V1, V2);
87   }
88
89   /// Similar to EquivalenceClasses, this stores the set of equivalent
90   /// types. Beyond EquivalenceClasses, it allows the user to specify
91   /// which element will act as leader through a StrictWeakOrdering
92   /// function.
93   template<typename ElemTy, typename StrictWeak>
94   class VISIBILITY_HIDDEN Synonyms {
95     std::map<ElemTy, unsigned> mapping;
96     std::vector<ElemTy> leaders;
97     StrictWeak swo;
98
99   public:
100     typedef unsigned iterator;
101     typedef const unsigned const_iterator;
102
103     // Inspection
104
105     bool empty() const {
106       return leaders.empty();
107     }
108
109     iterator findLeader(ElemTy e) {
110       typename std::map<ElemTy, unsigned>::iterator MI = mapping.find(e);
111       if (MI == mapping.end()) return 0;
112
113       return MI->second;
114     }
115
116     const_iterator findLeader(ElemTy e) const {
117       typename std::map<ElemTy, unsigned>::const_iterator MI =
118           mapping.find(e);
119       if (MI == mapping.end()) return 0;
120
121       return MI->second;
122     }
123
124     ElemTy &getLeader(iterator I) {
125       assert(I != 0 && "Element zero is out of range.");
126       return leaders[I-1];
127     }
128
129     const ElemTy &getLeader(const_iterator I) const {
130       assert(I != 0 && "Element zero is out of range.");
131       return leaders[I-1];
132     }
133
134 #ifdef DEBUG
135     void debug(std::ostream &os) const {
136       for (unsigned i = 1, e = leaders.size()+1; i != e; ++i) {
137         os << i << ". " << *leaders[i-1] << ": [";
138         for (std::map<Value *, unsigned>::const_iterator
139              I = mapping.begin(), E = mapping.end(); I != E; ++I) {
140           if ((*I).second == i && (*I).first != leaders[i-1]) {
141             os << *(*I).first << "  ";
142           }
143         }
144         os << "]\n";
145       }
146     }
147 #endif
148
149     // Mutators
150
151     /// Combine two sets referring to the same element, inserting the
152     /// elements as needed. Returns a valid iterator iff two already
153     /// existing disjoint synonym sets were combined. The iterator
154     /// points to the removed element.
155     iterator unionSets(ElemTy E1, ElemTy E2) {
156       if (swo(E1, E2)) std::swap(E1, E2);
157
158       iterator I1 = findLeader(E1);
159       iterator I2 = findLeader(E2);
160
161       if (!I1 && !I2) { // neither entry is in yet
162         leaders.push_back(E1);
163         I1 = leaders.size();
164         mapping[E1] = I1;
165         mapping[E2] = I1;
166         return false;
167       }
168
169       if (!I1 && I2) {
170         mapping[E1] = I2;
171         return false;
172       }
173
174       if (I1 && !I2) {
175         mapping[E2] = I1;
176         return false;
177       }
178
179       // This is the case where we have two sets, [%a1, %a2, %a3] and
180       // [%p1, %p2, %p3] and someone says that %a2 == %p3. We need to
181       // combine the two synsets.
182
183       for (std::map<Value *, unsigned>::iterator I = mapping.begin(),
184            E = mapping.end(); I != E; ++I) {
185         if (I->second == I2) I->second = I1;
186         else if (I->second > I2) --I->second;
187       }
188
189       leaders.erase(leaders.begin() + I2 - 1);
190
191       return true;
192     }
193
194     /// Returns an iterator pointing to the synonym set containing
195     /// element e. If none exists, a new one is created and returned.
196     iterator findOrInsert(ElemTy e) {
197       iterator I = findLeader(e);
198       if (I) return I;
199
200       leaders.push_back(e);
201       I = leaders.size();
202       mapping[e] = I;
203       return I;
204     }
205   };
206
207   /// Represents the set of equivalent Value*s and provides insertion
208   /// and fast lookup. Also stores the set of inequality relationships.
209   class PropertySet {
210     struct Property;
211   public:
212     class Synonyms<Value *, compare> union_find;
213
214     typedef std::vector<Property>::iterator       PropertyIterator;
215     typedef std::vector<Property>::const_iterator ConstPropertyIterator;
216     typedef Synonyms<Value *, compare>::iterator  SynonymIterator;
217
218     enum Ops {
219       EQ,
220       NE
221     };
222
223     Value *canonicalize(Value *V) const {
224       Value *C = lookup(V);
225       return C ? C : V;
226     }
227
228     Value *lookup(Value *V) const {
229       Synonyms<Value *, compare>::iterator SI = union_find.findLeader(V);
230       if (!SI) return NULL;
231       return union_find.getLeader(SI);
232     }
233
234     bool empty() const {
235       return union_find.empty();
236     }
237
238     void addEqual(Value *V1, Value *V2) {
239       // If %x = 0. and %y = -0., seteq %x, %y is true, but
240       // copysign(%x) is not the same as copysign(%y).
241       if (V2->getType()->isFloatingPoint()) return;
242
243       order(V1, V2);
244       if (isa<Constant>(V2)) return; // refuse to set false == true.
245
246       DEBUG(std::cerr << "equal: " << *V1 << " and " << *V2 << "\n");
247       SynonymIterator deleted = union_find.unionSets(V1, V2);
248       if (deleted) {
249         SynonymIterator replacement = union_find.findLeader(V1);
250         // Move Properties
251         for (PropertyIterator I = Properties.begin(), E = Properties.end();
252              I != E; ++I) {
253           if (I->I1 == deleted) I->I1 = replacement;
254           else if (I->I1 > deleted) --I->I1;
255           if (I->I2 == deleted) I->I2 = replacement;
256           else if (I->I2 > deleted) --I->I2;
257         }
258       }
259       addImpliedProperties(EQ, V1, V2);
260     }
261
262     void addNotEqual(Value *V1, Value *V2) {
263       // If %x = NAN then seteq %x, %x is false.
264       if (V2->getType()->isFloatingPoint()) return;
265
266       DEBUG(std::cerr << "not equal: " << *V1 << " and " << *V2 << "\n");
267       if (findProperty(NE, V1, V2) != Properties.end())
268         return; // found.
269
270       // Add the property.
271       SynonymIterator I1 = union_find.findOrInsert(V1),
272                       I2 = union_find.findOrInsert(V2);
273       Properties.push_back(Property(NE, I1, I2));
274       addImpliedProperties(NE, V1, V2);
275     }
276
277     PropertyIterator findProperty(Ops Opcode, Value *V1, Value *V2) {
278       assert(Opcode != EQ && "Can't findProperty on EQ."
279              "Use the lookup method instead.");
280
281       SynonymIterator I1 = union_find.findLeader(V1),
282                       I2 = union_find.findLeader(V2);
283       if (!I1 || !I2) return Properties.end();
284
285       return
286       find(Properties.begin(), Properties.end(), Property(Opcode, I1, I2));
287     }
288
289     ConstPropertyIterator
290     findProperty(Ops Opcode, Value *V1, Value *V2) const {
291       assert(Opcode != EQ && "Can't findProperty on EQ."
292              "Use the lookup method instead.");
293
294       SynonymIterator I1 = union_find.findLeader(V1),
295                       I2 = union_find.findLeader(V2);
296       if (!I1 || !I2) return Properties.end();
297
298       return
299       find(Properties.begin(), Properties.end(), Property(Opcode, I1, I2));
300     }
301
302   private:
303     // Represents Head OP [Tail1, Tail2, ...]
304     // For example: %x != %a, %x != %b.
305     struct VISIBILITY_HIDDEN Property {
306       typedef Synonyms<Value *, compare>::iterator Iter;
307
308       Property(Ops opcode, Iter i1, Iter i2)
309         : Opcode(opcode), I1(i1), I2(i2)
310       { assert(opcode != EQ && "Equality belongs in the synonym set, "
311                                "not a property."); }
312
313       bool operator==(const Property &P) const {
314         return (Opcode == P.Opcode) &&
315                ((I1 == P.I1 && I2 == P.I2) ||
316                 (I1 == P.I2 && I2 == P.I1));
317       }
318
319       Ops Opcode;
320       Iter I1, I2;
321     };
322
323     void add(Ops Opcode, Value *V1, Value *V2, bool invert) {
324       switch (Opcode) {
325         case EQ:
326           if (invert) addNotEqual(V1, V2);
327           else        addEqual(V1, V2);
328           break;
329         case NE:
330           if (invert) addEqual(V1, V2);
331           else        addNotEqual(V1, V2);
332           break;
333         default:
334           assert(0 && "Unknown property opcode.");
335       }
336     }
337
338     // Finds the properties implied by an equivalence and adds them too.
339     // Example: ("seteq %a, %b", true,  EQ) --> (%a, %b, EQ)
340     //          ("seteq %a, %b", false, EQ) --> (%a, %b, NE)
341     void addImpliedProperties(Ops Opcode, Value *V1, Value *V2) {
342       order(V1, V2);
343
344       if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V2)) {
345         switch (BO->getOpcode()) {
346         case Instruction::SetEQ:
347           if (V1 == ConstantBool::True)
348             add(Opcode, BO->getOperand(0), BO->getOperand(1), false);
349           if (V1 == ConstantBool::False)
350             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
351           break;
352         case Instruction::SetNE:
353           if (V1 == ConstantBool::True)
354             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
355           if (V1 == ConstantBool::False)
356             add(Opcode, BO->getOperand(0), BO->getOperand(1), false);
357           break;
358         case Instruction::SetLT:
359         case Instruction::SetGT:
360           if (V1 == ConstantBool::True)
361             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
362           break;
363         case Instruction::SetLE:
364         case Instruction::SetGE:
365           if (V1 == ConstantBool::False)
366             add(Opcode, BO->getOperand(0), BO->getOperand(1), true);
367           break;
368         case Instruction::And:
369           if (V1 == ConstantBool::True) {
370             add(Opcode, ConstantBool::True, BO->getOperand(0), false);
371             add(Opcode, ConstantBool::True, BO->getOperand(1), false);
372           }
373           break;
374         case Instruction::Or:
375           if (V1 == ConstantBool::False) {
376             add(Opcode, ConstantBool::False, BO->getOperand(0), false);
377             add(Opcode, ConstantBool::False, BO->getOperand(1), false);
378           }
379           break;
380         case Instruction::Xor:
381           if (V1 == ConstantBool::True) {
382             if (BO->getOperand(0) == ConstantBool::True)
383               add(Opcode, ConstantBool::False, BO->getOperand(1), false);
384             if (BO->getOperand(1) == ConstantBool::True)
385               add(Opcode, ConstantBool::False, BO->getOperand(0), false);
386           }
387           if (V1 == ConstantBool::False) {
388             if (BO->getOperand(0) == ConstantBool::True)
389               add(Opcode, ConstantBool::True, BO->getOperand(1), false);
390             if (BO->getOperand(1) == ConstantBool::True)
391               add(Opcode, ConstantBool::True, BO->getOperand(0), false);
392           }
393           break;
394         default:
395           break;
396         }
397       } else if (SelectInst *SI = dyn_cast<SelectInst>(V2)) {
398         if (Opcode != EQ && Opcode != NE) return;
399
400         ConstantBool *True  = (Opcode==EQ) ? ConstantBool::True
401                                            : ConstantBool::False,
402                      *False = (Opcode==EQ) ? ConstantBool::False
403                                            : ConstantBool::True;
404
405         if (V1 == SI->getTrueValue())
406           addEqual(SI->getCondition(), True);
407         else if (V1 == SI->getFalseValue())
408           addEqual(SI->getCondition(), False);
409         else if (Opcode == EQ)
410           assert("Result of select not equal to either value.");
411       }
412     }
413
414   public:
415 #ifdef DEBUG
416     void debug(std::ostream &os) const {
417       static const char *OpcodeTable[] = { "EQ", "NE" };
418
419       union_find.debug(os);
420       for (std::vector<Property>::const_iterator I = Properties.begin(),
421            E = Properties.end(); I != E; ++I) {
422         os << (*I).I1 << " " << OpcodeTable[(*I).Opcode] << " "
423            << (*I).I2 << "\n";
424       }
425       os << "\n";
426     }
427 #endif
428
429     std::vector<Property> Properties;
430   };
431
432   /// PredicateSimplifier - This class is a simplifier that replaces
433   /// one equivalent variable with another. It also tracks what
434   /// can't be equal and will solve setcc instructions when possible.
435   class PredicateSimplifier : public FunctionPass {
436   public:
437     bool runOnFunction(Function &F);
438     virtual void getAnalysisUsage(AnalysisUsage &AU) const;
439
440   private:
441     // Try to replace the Use of the instruction with something simpler.
442     Value *resolve(SetCondInst *SCI, const PropertySet &);
443     Value *resolve(BinaryOperator *BO, const PropertySet &);
444     Value *resolve(SelectInst *SI, const PropertySet &);
445     Value *resolve(Value *V, const PropertySet &);
446
447     // Used by terminator instructions to proceed from the current basic
448     // block to the next. Verifies that "current" dominates "next",
449     // then calls visitBasicBlock.
450     void proceedToSuccessor(PropertySet &CurrentPS, PropertySet &NextPS,
451                             DTNodeType *Current, DTNodeType *Next);
452     void proceedToSuccessor(PropertySet &CurrentPS,
453                             DTNodeType *Current, DTNodeType *Next);
454
455     // Visits each instruction in the basic block.
456     void visitBasicBlock(DTNodeType *DTNode, PropertySet &KnownProperties);
457
458     // Tries to simplify each Instruction and add new properties to
459     // the PropertySet. Returns true if it erase the instruction.
460     void visitInstruction(Instruction *I, DTNodeType *, PropertySet &);
461     // For each instruction, add the properties to KnownProperties.
462
463     void visit(TerminatorInst *TI, DTNodeType *, PropertySet &);
464     void visit(BranchInst *BI, DTNodeType *, PropertySet &);
465     void visit(SwitchInst *SI, DTNodeType *, PropertySet);
466     void visit(LoadInst *LI, DTNodeType *, PropertySet &);
467     void visit(StoreInst *SI, DTNodeType *, PropertySet &);
468     void visit(BinaryOperator *BO, DTNodeType *, PropertySet &);
469
470     DominatorTree *DT;
471     bool modified;
472   };
473
474   RegisterPass<PredicateSimplifier> X("predsimplify",
475                                       "Predicate Simplifier");
476 }
477
478 FunctionPass *llvm::createPredicateSimplifierPass() {
479   return new PredicateSimplifier();
480 }
481
482 bool PredicateSimplifier::runOnFunction(Function &F) {
483   DT = &getAnalysis<DominatorTree>();
484
485   modified = false;
486   PropertySet KnownProperties;
487   visitBasicBlock(DT->getRootNode(), KnownProperties);
488   return modified;
489 }
490
491 void PredicateSimplifier::getAnalysisUsage(AnalysisUsage &AU) const {
492   AU.addRequired<DominatorTree>();
493 }
494
495 // resolve catches cases addProperty won't because it wasn't used as a
496 // condition in the branch, and that visit won't, because the instruction
497 // was defined outside of the scope that the properties apply to.
498 Value *PredicateSimplifier::resolve(SetCondInst *SCI,
499                                     const PropertySet &KP) {
500   // Attempt to resolve the SetCondInst to a boolean.
501
502   Value *SCI0 = resolve(SCI->getOperand(0), KP),
503         *SCI1 = resolve(SCI->getOperand(1), KP);
504
505   ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(SCI0),
506                    *CI2 = dyn_cast<ConstantIntegral>(SCI1);
507
508   if (!CI1 || !CI2) {
509     PropertySet::ConstPropertyIterator NE =
510         KP.findProperty(PropertySet::NE, SCI0, SCI1);
511
512     if (NE != KP.Properties.end()) {
513       switch (SCI->getOpcode()) {
514         case Instruction::SetEQ:
515           return ConstantBool::False;
516         case Instruction::SetNE:
517           return ConstantBool::True;
518         case Instruction::SetLE:
519         case Instruction::SetGE:
520         case Instruction::SetLT:
521         case Instruction::SetGT:
522           break;
523         default:
524           assert(0 && "Unknown opcode in SetCondInst.");
525           break;
526       }
527     }
528     return SCI;
529   }
530
531   switch(SCI->getOpcode()) {
532     case Instruction::SetLE:
533     case Instruction::SetGE:
534     case Instruction::SetEQ:
535       if (CI1->getRawValue() == CI2->getRawValue())
536         return ConstantBool::True;
537       else
538         return ConstantBool::False;
539     case Instruction::SetLT:
540     case Instruction::SetGT:
541     case Instruction::SetNE:
542       if (CI1->getRawValue() == CI2->getRawValue())
543         return ConstantBool::False;
544       else
545         return ConstantBool::True;
546     default:
547       assert(0 && "Unknown opcode in SetContInst.");
548       break;
549   }
550 }
551
552 Value *PredicateSimplifier::resolve(BinaryOperator *BO,
553                                     const PropertySet &KP) {
554   if (SetCondInst *SCI = dyn_cast<SetCondInst>(BO))
555     return resolve(SCI, KP);
556
557   Value *lhs = resolve(BO->getOperand(0), KP),
558         *rhs = resolve(BO->getOperand(1), KP);
559   ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(lhs);
560   ConstantIntegral *CI2 = dyn_cast<ConstantIntegral>(rhs);
561
562   if (!CI1 || !CI2) return BO;
563
564   Value *V = ConstantExpr::get(BO->getOpcode(), CI1, CI2);
565   if (V) return V;
566   return BO;
567 }
568
569 Value *PredicateSimplifier::resolve(SelectInst *SI, const PropertySet &KP) {
570   Value *Condition = resolve(SI->getCondition(), KP);
571   if (Condition == ConstantBool::True)
572     return resolve(SI->getTrueValue(), KP);
573   else if (Condition == ConstantBool::False)
574     return resolve(SI->getFalseValue(), KP);
575   return SI;
576 }
577
578 Value *PredicateSimplifier::resolve(Value *V, const PropertySet &KP) {
579   if (isa<Constant>(V) || isa<BasicBlock>(V) || KP.empty()) return V;
580
581   V = KP.canonicalize(V);
582
583   DEBUG(std::cerr << "peering into " << *V << "\n");
584
585   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V))
586     return resolve(BO, KP);
587   else if (SelectInst *SI = dyn_cast<SelectInst>(V))
588     return resolve(SI, KP);
589
590   return V;
591 }
592
593 void PredicateSimplifier::visitBasicBlock(DTNodeType *DTNode,
594                                           PropertySet &KnownProperties) {
595   BasicBlock *BB = DTNode->getBlock();
596   for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
597     visitInstruction(I, DTNode, KnownProperties);
598   }
599 }
600
601 void PredicateSimplifier::visitInstruction(Instruction *I,
602                                            DTNodeType *DTNode,
603                                            PropertySet &KnownProperties) {
604
605   DEBUG(std::cerr << "Considering instruction " << *I << "\n");
606   DEBUG(KnownProperties.debug(std::cerr));
607
608   // Try to replace the whole instruction.
609   Value *V = resolve(I, KnownProperties);
610   assert(V && "resolve not supposed to return NULL.");
611   if (V != I) {
612     modified = true;
613     ++NumInstruction;
614     I->replaceAllUsesWith(V);
615     return;
616   }
617
618   // Try to substitute operands.
619   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
620     Value *Oper = I->getOperand(i);
621     Value *V = resolve(Oper, KnownProperties);
622     assert(V && "resolve not supposed to return NULL.");
623     if (V != Oper) {
624       modified = true;
625       ++NumVarsReplaced;
626       DEBUG(std::cerr << "resolving " << *I);
627       I->setOperand(i, V);
628       DEBUG(std::cerr << "into " << *I);
629     }
630   }
631
632   if (TerminatorInst *TI = dyn_cast<TerminatorInst>(I))
633     visit(TI, DTNode, KnownProperties);
634   else if (LoadInst *LI = dyn_cast<LoadInst>(I))
635     visit(LI, DTNode, KnownProperties);
636   else if (StoreInst *SI = dyn_cast<StoreInst>(I))
637     visit(SI, DTNode, KnownProperties);
638   else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I))
639     visit(BO, DTNode, KnownProperties);
640 }
641
642 void PredicateSimplifier::proceedToSuccessor(PropertySet &CurrentPS,
643                                              PropertySet &NextPS,
644                                              DTNodeType *Current,
645                                              DTNodeType *Next) {
646   if (Next->getBlock()->getSinglePredecessor() == Current->getBlock())
647     proceedToSuccessor(NextPS, Current, Next);
648   else
649     proceedToSuccessor(CurrentPS, Current, Next);
650 }
651
652 void PredicateSimplifier::proceedToSuccessor(PropertySet &KP,
653                                              DTNodeType *Current,
654                                              DTNodeType *Next) {
655   if (Current->properlyDominates(Next))
656     visitBasicBlock(Next, KP);
657 }
658
659 void PredicateSimplifier::visit(TerminatorInst *TI, DTNodeType *Node,
660                                 PropertySet &KP) {
661   if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
662     visit(BI, Node, KP);
663     return;
664   }
665   if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
666     visit(SI, Node, KP);
667     return;
668   }
669
670   for (unsigned i = 0, E = TI->getNumSuccessors(); i != E; ++i) {
671     BasicBlock *BB = TI->getSuccessor(i);
672     PropertySet KPcopy(KP);
673     proceedToSuccessor(KPcopy, Node, DT->getNode(TI->getSuccessor(i)));
674   }
675 }
676
677 void PredicateSimplifier::visit(BranchInst *BI, DTNodeType *Node,
678                                 PropertySet &KP) {
679   if (BI->isUnconditional()) {
680     proceedToSuccessor(KP, Node, DT->getNode(BI->getSuccessor(0)));
681     return;
682   }
683
684   Value *Condition = BI->getCondition();
685
686   BasicBlock *TrueDest  = BI->getSuccessor(0),
687              *FalseDest = BI->getSuccessor(1);
688
689   if (Condition == ConstantBool::True) {
690     FalseDest->removePredecessor(BI->getParent());
691     BI->setUnconditionalDest(TrueDest);
692     modified = true;
693     ++NumBranches;
694     proceedToSuccessor(KP, Node, DT->getNode(TrueDest));
695     return;
696   } else if (Condition == ConstantBool::False) {
697     TrueDest->removePredecessor(BI->getParent());
698     BI->setUnconditionalDest(FalseDest);
699     modified = true;
700     ++NumBranches;
701     proceedToSuccessor(KP, Node, DT->getNode(FalseDest));
702     return;
703   }
704
705   PropertySet TrueProperties(KP), FalseProperties(KP);
706   DEBUG(std::cerr << "true set:\n");
707   TrueProperties.addEqual(ConstantBool::True,   Condition);
708   DEBUG(TrueProperties.debug(std::cerr));
709   DEBUG(std::cerr << "false set:\n");
710   FalseProperties.addEqual(ConstantBool::False, Condition);
711   DEBUG(FalseProperties.debug(std::cerr));
712
713   PropertySet KPcopy(KP);
714   proceedToSuccessor(KP,     TrueProperties,  Node, DT->getNode(TrueDest));
715   proceedToSuccessor(KPcopy, FalseProperties, Node, DT->getNode(FalseDest));
716 }
717
718 void PredicateSimplifier::visit(SwitchInst *SI, DTNodeType *DTNode,
719                                 PropertySet KP) {
720   Value *Condition = SI->getCondition();
721   assert(Condition == KP.canonicalize(Condition) &&
722          "Instruction wasn't already canonicalized?");
723
724   // If there's an NEProperty covering this SwitchInst, we may be able to
725   // eliminate one of the cases.
726   for (PropertySet::ConstPropertyIterator I = KP.Properties.begin(),
727        E = KP.Properties.end(); I != E; ++I) {
728     if (I->Opcode != PropertySet::NE) continue;
729     Value *V1 = KP.union_find.getLeader(I->I1),
730           *V2 = KP.union_find.getLeader(I->I2);
731
732     // Find a Property with a ConstantInt on one side and our
733     // Condition on the other.
734     ConstantInt *CI = NULL;
735     if (V1 == Condition)
736       CI = dyn_cast<ConstantInt>(V2);
737     else if (V2 == Condition)
738       CI = dyn_cast<ConstantInt>(V1);
739
740     if (!CI) continue;
741
742     unsigned i = SI->findCaseValue(CI);
743     if (i != 0) { // zero is reserved for the default case.
744       SI->getSuccessor(i)->removePredecessor(SI->getParent());
745       SI->removeCase(i);
746       modified = true;
747       ++NumSwitchCases;
748     }
749   }
750
751   // Set the EQProperty in each of the cases BBs,
752   // and the NEProperties in the default BB.
753   PropertySet DefaultProperties(KP);
754
755   DTNodeType *Node        = DT->getNode(SI->getParent()),
756              *DefaultNode = DT->getNode(SI->getSuccessor(0));
757   if (!Node->dominates(DefaultNode)) DefaultNode = NULL;
758
759   for (unsigned I = 1, E = SI->getNumCases(); I < E; ++I) {
760     ConstantInt *CI = SI->getCaseValue(I);
761
762     BasicBlock *SuccBB = SI->getSuccessor(I);
763     PropertySet copy(KP);
764     if (SuccBB->getSinglePredecessor()) {
765       PropertySet NewProperties(KP);
766       NewProperties.addEqual(Condition, CI);
767       proceedToSuccessor(copy, NewProperties, DTNode, DT->getNode(SuccBB));
768     } else
769       proceedToSuccessor(copy, DTNode, DT->getNode(SuccBB));
770
771     if (DefaultNode)
772       DefaultProperties.addNotEqual(Condition, CI);
773   }
774
775   if (DefaultNode)
776     proceedToSuccessor(DefaultProperties, DTNode, DefaultNode);
777 }
778
779 void PredicateSimplifier::visit(LoadInst *LI, DTNodeType *,
780                                 PropertySet &KP) {
781   Value *Ptr = LI->getPointerOperand();
782   KP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr);
783 }
784
785 void PredicateSimplifier::visit(StoreInst *SI, DTNodeType *,
786                                 PropertySet &KP) {
787   Value *Ptr = SI->getPointerOperand();
788   KP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr);
789 }
790
791 void PredicateSimplifier::visit(BinaryOperator *BO, DTNodeType *,
792                                 PropertySet &KP) {
793   Instruction::BinaryOps ops = BO->getOpcode();
794
795   switch (ops) {
796     case Instruction::Div:
797     case Instruction::Rem: {
798       Value *Divisor = BO->getOperand(1);
799       KP.addNotEqual(Constant::getNullValue(Divisor->getType()), Divisor);
800       break;
801     }
802     default:
803       break;
804   }
805 }