1 //===-- PredicateSimplifier.cpp - Path Sensitive Simplifier ---------------===//
3 // The LLVM Compiler Infrastructure
5 // This file was developed by Nick Lewycky and is distributed under the
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Path-sensitive optimizer. In a branch where x == y, replace uses of
11 // x with y. Permits further optimization, such as the elimination of
12 // the unreachable call:
14 // void test(int *p, int *q)
20 // foo(); // unreachable
23 //===----------------------------------------------------------------------===//
25 // This pass focusses on four properties; equals, not equals, less-than
26 // and less-than-or-equals-to. The greater-than forms are also held just
27 // to allow walking from a lesser node to a greater one. These properties
28 // are stored in a lattice; LE can become LT or EQ, NE can become LT or GT.
30 // These relationships define a graph between values of the same type. Each
31 // Value is stored in a map table that retrieves the associated Node. This
32 // is how EQ relationships are stored; the map contains pointers to the
33 // same node. The node contains a most canonical Value* form and the list of
34 // known relationships.
36 // If two nodes are known to be inequal, then they will contain pointers to
37 // each other with an "NE" relationship. If node getNode(%x) is less than
38 // getNode(%y), then the %x node will contain <%y, GT> and %y will contain
39 // <%x, LT>. This allows us to tie nodes together into a graph like this:
43 // with four nodes representing the properties. The InequalityGraph provides
44 // queries (such as "isEqual") and mutators (such as "addEqual"). To implement
45 // "isLess(%a, %c)", we start with getNode(%c) and walk downwards until
46 // we reach %a or the leaf node. Note that the graph is directed and acyclic,
47 // but may contain joins, meaning that this walk is not a linear time
50 // To create these properties, we wait until a branch or switch instruction
51 // implies that a particular value is true (or false). The VRPSolver is
52 // responsible for analyzing the variable and seeing what new inferences
53 // can be made from each property. For example:
55 // %P = seteq int* %ptr, null
56 // %a = or bool %P, %Q
57 // br bool %a label %cond_true, label %cond_false
59 // For the true branch, the VRPSolver will start with %a EQ true and look at
60 // the definition of %a and find that it can infer that %P and %Q are both
61 // true. From %P being true, it can infer that %ptr NE null. For the false
62 // branch it can't infer anything from the "or" instruction.
64 // Besides branches, we can also infer properties from instructions that may
65 // have undefined behaviour in certain cases. For example, the divisor of
66 // a division may never be zero. After the division instruction, we may assume
67 // that the divisor is not equal to zero.
69 //===----------------------------------------------------------------------===//
71 #define DEBUG_TYPE "predsimplify"
72 #include "llvm/Transforms/Scalar.h"
73 #include "llvm/Constants.h"
74 #include "llvm/DerivedTypes.h"
75 #include "llvm/Instructions.h"
76 #include "llvm/Pass.h"
77 #include "llvm/ADT/SetOperations.h"
78 #include "llvm/ADT/SmallVector.h"
79 #include "llvm/ADT/Statistic.h"
80 #include "llvm/ADT/STLExtras.h"
81 #include "llvm/Analysis/Dominators.h"
82 #include "llvm/Analysis/ET-Forest.h"
83 #include "llvm/Assembly/Writer.h"
84 #include "llvm/Support/CFG.h"
85 #include "llvm/Support/Compiler.h"
86 #include "llvm/Support/Debug.h"
87 #include "llvm/Support/InstVisitor.h"
88 #include "llvm/Transforms/Utils/Local.h"
95 STATISTIC(NumVarsReplaced, "Number of argument substitutions");
96 STATISTIC(NumInstruction , "Number of instructions removed");
97 STATISTIC(NumSimple , "Number of simple replacements");
100 /// The InequalityGraph stores the relationships between values.
101 /// Each Value in the graph is assigned to a Node. Nodes are pointer
102 /// comparable for equality. The caller is expected to maintain the logical
103 /// consistency of the system.
105 /// The InequalityGraph class may invalidate Node*s after any mutator call.
106 /// @brief The InequalityGraph stores the relationships between values.
107 class VISIBILITY_HIDDEN InequalityGraph {
112 // 0 0 0 -- invalid (false)
113 // 0 0 1 -- invalid (EQ)
119 // 1 1 1 -- invalid (true)
121 EQ_BIT = 1, GT_BIT = 2, LT_BIT = 4
124 GT = GT_BIT, GE = GT_BIT | EQ_BIT,
125 LT = LT_BIT, LE = LT_BIT | EQ_BIT,
129 static bool validPredicate(LatticeVal LV) {
130 return LV > 1 && LV < 7;
134 typedef std::map<Value *, Node *> NodeMapType;
137 const InequalityGraph *ConcreteIG;
140 /// A single node in the InequalityGraph. This stores the canonical Value
141 /// for the node, as well as the relationships with the neighbours.
143 /// Because the lists are intended to be used for traversal, it is invalid
144 /// for the node to list itself in LessEqual or GreaterEqual lists. The
145 /// fact that a node is equal to itself is implied, and may be checked
146 /// with pointer comparison.
147 /// @brief A single node in the InequalityGraph.
148 class VISIBILITY_HIDDEN Node {
149 friend class InequalityGraph;
153 typedef SmallVector<std::pair<Node *, LatticeVal>, 4> RelationsType;
154 RelationsType Relations;
156 typedef RelationsType::iterator iterator;
157 typedef RelationsType::const_iterator const_iterator;
160 /// Updates the lattice value for a given node. Create a new entry if
161 /// one doesn't exist, otherwise it merges the values. The new lattice
162 /// value must not be inconsistent with any previously existing value.
163 void update(Node *N, LatticeVal R) {
164 iterator I = find(N);
166 Relations.push_back(std::make_pair(N, R));
168 I->second = static_cast<LatticeVal>(I->second & R);
169 assert(validPredicate(I->second) &&
170 "Invalid union of lattice values.");
174 void assign(Node *N, LatticeVal R) {
175 iterator I = find(N);
176 if (I != end()) I->second = R;
178 Relations.push_back(std::make_pair(N, R));
182 iterator begin() { return Relations.begin(); } // first stored (neighbour, lattice value) pair
183 iterator end() { return Relations.end(); } // one past the last pair; callers compare find() against this
184 iterator find(Node *N) {
185 iterator I = begin();
186 for (iterator E = end(); I != E; ++I)
187 if (I->first == N) break;
191 const_iterator begin() const { return Relations.begin(); } // read-only start of the relation list
192 const_iterator end() const { return Relations.end(); } // read-only past-the-end of the relation list
193 const_iterator find(Node *N) const {
194 const_iterator I = begin();
195 for (const_iterator E = end(); I != E; ++I)
196 if (I->first == N) break;
200 unsigned findIndex(Node *N) {
202 iterator I = begin();
203 for (iterator E = end(); I != E; ++I, ++i)
204 if (I->first == N) return i;
208 void erase(iterator i) { Relations.erase(i); } // removes one entry from this node's list only
210 Value *getValue() const { return Canonical; } // the canonical Value stored for this node
211 void setValue(Value *V) { Canonical = V; } // replaces the canonical Value with V
213 void addNotEqual(Node *N) { update(N, NE); } // records N != this; only this node's list is updated
214 void addLess(Node *N) { update(N, LT); } // records N < this (entry <N, LT>)
215 void addLessEqual(Node *N) { update(N, LE); } // records N <= this (entry <N, LE>)
216 void addGreater(Node *N) { update(N, GT); } // records N > this (entry <N, GT>)
217 void addGreaterEqual(Node *N) { update(N, GE); } // records N >= this (entry <N, GE>)
220 InequalityGraph() : ConcreteIG(NULL) {} // no ConcreteIG, so const lookups use this graph's own Nodes
222 InequalityGraph(const InequalityGraph &_IG) {
224 if (_IG.ConcreteIG) ConcreteIG = _IG.ConcreteIG;
225 else ConcreteIG = &_IG;
238 /// If the Value is in the graph, return the canonical form. Otherwise,
239 /// return the original Value.
240 Value *canonicalize(Value *V) const {
241 if (const Node *N = getNode(V))
242 return N->getValue();
247 /// Returns the node currently representing Value V, or null if no such
249 Node *getNode(Value *V) {
252 NodeMapType::const_iterator I = Nodes.find(V);
253 return (I != Nodes.end()) ? I->second : 0;
256 const Node *getNode(Value *V) const {
257 if (ConcreteIG) return ConcreteIG->getNode(V);
259 NodeMapType::const_iterator I = Nodes.find(V);
260 return (I != Nodes.end()) ? I->second : 0;
263 Node *getOrInsertNode(Value *V) {
264 if (Node *N = getNode(V))
270 Node *newNode(Value *V) {
271 //DOUT << "new node: " << *V << "\n";
274 assert(N == 0 && "Node already exists for value.");
280 /// Returns true iff the nodes are provably inequal.
281 bool isNotEqual(const Node *N1, const Node *N2) const {
282 if (N1 == N2) return false;
283 for (Node::const_iterator I = N1->begin(), E = N1->end(); I != E; ++I) {
285 return (I->second & EQ_BIT) == 0;
287 return isLess(N1, N2) || isGreater(N1, N2);
290 /// Returns true iff N1 is provably less than N2.
291 bool isLess(const Node *N1, const Node *N2) const {
292 if (N1 == N2) return false;
293 for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) {
295 return I->second == LT;
297 for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) {
298 if ((I->second & (LT_BIT | GT_BIT)) == LT_BIT)
299 if (isLess(N1, I->first)) return true;
304 /// Returns true iff N1 is provably less than or equal to N2.
305 bool isLessEqual(const Node *N1, const Node *N2) const {
306 if (N1 == N2) return true;
307 for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) {
309 return (I->second & (LT_BIT | GT_BIT)) == LT_BIT;
311 for (Node::const_iterator I = N2->begin(), E = N2->end(); I != E; ++I) {
312 if ((I->second & (LT_BIT | GT_BIT)) == LT_BIT)
313 if (isLessEqual(N1, I->first)) return true;
318 /// Returns true iff N1 is provably greater than N2.
319 bool isGreater(const Node *N1, const Node *N2) const {
320 return isLess(N2, N1);
323 /// Returns true iff N1 is provably greater than or equal to N2.
324 bool isGreaterEqual(const Node *N1, const Node *N2) const {
325 return isLessEqual(N2, N1);
328 // The add* methods assume that your input is logically valid and may
329 // assertion-fail or infinitely loop if you attempt a contradiction.
331 void addEqual(Node *N, Value *V) {
336 void addNotEqual(Node *N1, Node *N2) {
337 assert(N1 != N2 && "A node can't be inequal to itself.");
343 /// N1 is less than N2.
344 void addLess(Node *N1, Node *N2) {
345 assert(N1 != N2 && !isLess(N2, N1) && "Attempt to create < cycle.");
351 /// N1 is less than or equal to N2.
352 void addLessEqual(Node *N1, Node *N2) {
353 assert(N1 != N2 && "Nodes are equal. Use mergeNodes instead.");
354 assert(!isGreater(N1, N2) && "Impossible: Adding x <= y when x > y.");
356 N2->addLessEqual(N1);
357 N1->addGreaterEqual(N2);
360 /// Find the transitive closure starting at a node walking down the edges
361 /// of type Val. Type Inserter must be an inserter that accepts Node *.
362 template <typename Inserter>
363 void transitiveClosure(Node *N, LatticeVal Val, Inserter insert) {
364 for (Node::iterator I = N->begin(), E = N->end(); I != E; ++I) {
365 if (I->second == Val) {
367 transitiveClosure(I->first, Val, insert);
372 /// Kills off all the nodes in Kill by replicating their properties into
373 /// node N. The elements of Kill must be unique. After merging, N's new
374 /// canonical value is NewCanonical. Type C must be a container of Node *.
375 template <typename C>
376 void mergeNodes(Node *N, C &Kill, Value *NewCanonical);
378 /// Removes a Value from the graph, but does not delete any nodes. As this
379 /// method does not delete Nodes, V may not be the canonical choice for
381 void remove(Value *V) {
384 for (NodeMapType::iterator I = Nodes.begin(), E = Nodes.end(); I != E;) {
385 NodeMapType::iterator J = I++;
386 assert(J->second->getValue() != V && "Can't delete canonical choice.");
387 if (J->first == V) Nodes.erase(J);
392 void debug(std::ostream &os) const {
393 std::set<Node *> VisitedNodes;
394 for (NodeMapType::const_iterator I = Nodes.begin(), E = Nodes.end();
397 os << *I->first << " == " << *N->getValue() << "\n";
398 if (VisitedNodes.insert(N).second) {
399 os << *N->getValue() << ":\n";
400 for (Node::const_iterator NI = N->begin(), NE = N->end();
402 static const std::string names[8] =
403 { "00", "01", " <", "<=", " >", ">=", "!=", "07" };
404 os << " " << names[NI->second] << " "
405 << *NI->first->getValue() << "\n";
413 InequalityGraph::~InequalityGraph() {
414 if (ConcreteIG) return;
416 std::vector<Node *> Remove;
417 for (NodeMapType::iterator I = Nodes.begin(), E = Nodes.end();
419 if (I->first == I->second->getValue())
420 Remove.push_back(I->second);
422 for (std::vector<Node *>::iterator I = Remove.begin(), E = Remove.end();
428 template <typename C>
429 void InequalityGraph::mergeNodes(Node *N, C &Kill, Value *NewCanonical) {
432 // Merge the relationships from the members of Kill into N.
433 for (typename C::iterator KI = Kill.begin(), KE = Kill.end();
436 for (Node::iterator I = (*KI)->begin(), E = (*KI)->end(); I != E; ++I) {
437 if (I->first == N) continue;
439 Node::iterator NI = N->find(I->first);
440 if (NI == N->end()) {
441 N->Relations.push_back(std::make_pair(I->first, I->second));
443 unsigned char LV = NI->second & I->second;
446 assert(std::find(Kill.begin(), Kill.end(), I->first) != Kill.end()
447 && "Lost EQ property.");
450 NI->second = static_cast<LatticeVal>(LV);
451 assert(InequalityGraph::validPredicate(NI->second) &&
452 "Invalid union of lattice values.");
456 // All edges are reciprocal; every Node that Kill points to also
457 // contains a pointer to Kill. Replace those with pointers with N.
458 unsigned iter = I->first->findIndex(*KI);
459 assert(iter != (unsigned)-1 && "Edge not reciprocal.");
460 I->first->assign(N, (I->first->begin()+iter)->second);
461 I->first->erase(I->first->begin()+iter);
464 // Removing references from N to Kill.
465 Node::iterator NI = N->find(*KI);
466 if (NI != N->end()) {
467 N->erase(NI); // breaks reciprocity until Kill is deleted.
471 N->setValue(NewCanonical);
473 // Update value mapping to point to the merged node.
474 for (NodeMapType::iterator I = Nodes.begin(), E = Nodes.end();
476 if (std::find(Kill.begin(), Kill.end(), I->second) != Kill.end())
480 for (typename C::iterator KI = Kill.begin(), KE = Kill.end();
486 void InequalityGraph::materialize() {
487 if (!ConcreteIG) return;
488 const InequalityGraph *IG = ConcreteIG;
491 for (NodeMapType::const_iterator I = IG->Nodes.begin(),
492 E = IG->Nodes.end(); I != E; ++I) {
493 if (I->first == I->second->getValue()) { // only the canonical key of each node creates the copy
494 Node *N = newNode(I->first);
495 N->Relations.reserve(I->second->Relations.size()); // size from the source node; reserving N's own size did nothing
498 for (NodeMapType::const_iterator I = IG->Nodes.begin(),
499 E = IG->Nodes.end(); I != E; ++I) {
500 if (I->first != I->second->getValue()) {
501 Nodes[I->first] = getNode(I->second->getValue());
503 Node *Old = I->second;
504 Node *N = getNode(I->first);
505 for (Node::const_iterator NI = Old->begin(), NE = Old->end();
507 N->assign(getNode(NI->first->getValue()), NI->second);
513 /// VRPSolver keeps track of how changes to one variable affect other
514 /// variables, and forwards changes along to the InequalityGraph. It
515 /// also maintains the correct choice for "canonical" in the IG.
516 /// @brief VRPSolver calculates inferences from a new relationship.
517 class VISIBILITY_HIDDEN VRPSolver {
519 std::deque<Instruction *> WorkList;
522 const InequalityGraph &cIG;
526 typedef InequalityGraph::Node Node;
528 /// Returns true if V1 is a better canonical value than V2.
529 bool compare(Value *V1, Value *V2) const {
530 if (isa<Constant>(V1))
531 return !isa<Constant>(V2);
532 else if (isa<Constant>(V2))
534 else if (isa<Argument>(V1))
535 return !isa<Argument>(V2);
536 else if (isa<Argument>(V2))
539 Instruction *I1 = dyn_cast<Instruction>(V1);
540 Instruction *I2 = dyn_cast<Instruction>(V2);
542 if (!I1 || !I2) return false;
544 BasicBlock *BB1 = I1->getParent(),
545 *BB2 = I2->getParent();
547 for (BasicBlock::const_iterator I = BB1->begin(), E = BB1->end();
549 if (&*I == I1) return true;
550 if (&*I == I2) return false;
552 assert(!"Instructions not found in parent BasicBlock?");
554 return Forest->properlyDominates(BB1, BB2);
559 void addToWorklist(Instruction *I) {
560 //DOUT << "addToWorklist: " << *I << "\n";
562 if (!isa<BinaryOperator>(I) && !isa<SelectInst>(I)) return;
564 const Type *Ty = I->getType();
565 if (Ty == Type::VoidTy || Ty->isFPOrFPVector()) return;
567 if (isInstructionTriviallyDead(I)) return;
569 WorkList.push_back(I);
572 void addRecursive(Value *V) {
573 //DOUT << "addRecursive: " << *V << "\n";
575 Instruction *I = dyn_cast<Instruction>(V);
578 else if (!isa<Argument>(V))
581 //DOUT << "addRecursive uses...\n";
582 for (Value::use_iterator UI = V->use_begin(), UE = V->use_end();
584 // Use must either be dominated by Top, or dominate Top.
585 if (Instruction *Inst = dyn_cast<Instruction>(*UI)) {
586 ETNode *INode = Forest->getNodeForBlock(Inst->getParent());
587 if (INode->DominatedBy(Top) || Top->DominatedBy(INode))
593 //DOUT << "addRecursive ops...\n";
594 for (User::op_iterator OI = I->op_begin(), OE = I->op_end();
596 if (Instruction *Inst = dyn_cast<Instruction>(*OI))
600 //DOUT << "exit addRecursive (" << *V << ").\n";
604 VRPSolver(InequalityGraph &IG, ETForest *Forest, BasicBlock *TopBB)
605 : IG(IG), cIG(IG), Forest(Forest), Top(Forest->getNodeForBlock(TopBB)) {}
607 bool isEqual(Value *V1, Value *V2) const {
608 if (V1 == V2) return true;
609 if (const Node *N1 = cIG.getNode(V1))
610 return N1 == cIG.getNode(V2);
614 bool isNotEqual(Value *V1, Value *V2) const {
615 if (V1 == V2) return false;
616 if (const Node *N1 = cIG.getNode(V1))
617 if (const Node *N2 = cIG.getNode(V2))
618 return cIG.isNotEqual(N1, N2);
622 bool isLess(Value *V1, Value *V2) const {
623 if (V1 == V2) return false;
624 if (const Node *N1 = cIG.getNode(V1))
625 if (const Node *N2 = cIG.getNode(V2))
626 return cIG.isLess(N1, N2);
630 bool isLessEqual(Value *V1, Value *V2) const {
631 if (V1 == V2) return true;
632 if (const Node *N1 = cIG.getNode(V1))
633 if (const Node *N2 = cIG.getNode(V2))
634 return cIG.isLessEqual(N1, N2);
638 bool isGreater(Value *V1, Value *V2) const {
639 if (V1 == V2) return false;
640 if (const Node *N1 = cIG.getNode(V1))
641 if (const Node *N2 = cIG.getNode(V2))
642 return cIG.isGreater(N1, N2);
646 bool isGreaterEqual(Value *V1, Value *V2) const { // true iff V1 >= V2 is provable; identical values trivially are
647 if (V1 == V2) return true;
648 if (const Node *N1 = cIG.getNode(V1)) // look up via the const view, as every sibling predicate does
649 if (const Node *N2 = cIG.getNode(V2))
650 return cIG.isGreaterEqual(N1, N2);
654 // All of the add* functions return true if the InequalityGraph represents
655 // the property, and false if there is a logical contradiction. On false,
656 // you may no longer perform any queries on the InequalityGraph.
658 bool addEqual(Value *V1, Value *V2) {
659 //DOUT << "addEqual(" << *V1 << ", " << *V2 << ")\n";
660 if (isEqual(V1, V2)) return true;
662 const Node *cN1 = cIG.getNode(V1), *cN2 = cIG.getNode(V2);
664 if (cN1 && cN2 && cIG.isNotEqual(cN1, cN2))
667 if (compare(V2, V1)) { std::swap(V1, V2); std::swap(cN1, cN2); }
670 if (ConstantBool *CB = dyn_cast<ConstantBool>(V1)) {
671 Node *N1 = IG.getNode(V1);
673 // When "addEqual" is performed and the new value is a ConstantBool,
674 // iterate through the NE set and fix them up to be EQ of the
677 for (Node::iterator I = N1->begin(), E = N1->end(); I != E; ++I)
678 if ((I->second & 1) == 0) {
679 assert(N1 != I->first && "Node related to itself?");
680 addEqual(I->first->getValue(),
681 ConstantBool::get(!CB->getValue()));
687 if (Instruction *I2 = dyn_cast<Instruction>(V2)) {
688 ETNode *Node_I2 = Forest->getNodeForBlock(I2->getParent());
689 if (Top != Node_I2 && Node_I2->DominatedBy(Top)) {
691 if (cN1 && compare(V1, cN1->getValue())) V = cN1->getValue();
692 //DOUT << "Simply removing " << *I2
693 // << ", replacing with " << *V << "\n";
694 I2->replaceAllUsesWith(V);
695 // leave it dead; it'll get erased later.
703 Node *N1 = IG.getNode(V1), *N2 = IG.getNode(V2);
707 if (compare(V1, N1->getValue())) N1->setValue(V1);
711 if (compare(V1, N2->getValue())) N2->setValue(V1);
714 // Suppose we're being told that %x == %y, and %x <= %z and %y >= %z.
715 // We can't just merge %x and %y because the relationship with %z would
716 // be EQ and that's invalid; they need to be the same Node.
718 // What we're doing is looking for any chain of nodes reaching %z such
719 // that %x <= %z and %y >= %z, and vice versa. The cool part is that
720 // every node in between is also equal because of the squeeze principle.
722 std::vector<Node *> N1_GE, N2_LE, N1_LE, N2_GE;
723 // Sort and dedupe each closure right after it is filled; std::set_intersection below requires sorted ranges.
724 IG.transitiveClosure(N1, InequalityGraph::GE, back_inserter(N1_GE));
725 std::sort(N1_GE.begin(), N1_GE.end());
726 N1_GE.erase(std::unique(N1_GE.begin(), N1_GE.end()), N1_GE.end());
727 IG.transitiveClosure(N2, InequalityGraph::LE, back_inserter(N2_LE));
728 std::sort(N2_LE.begin(), N2_LE.end());
729 N2_LE.erase(std::unique(N2_LE.begin(), N2_LE.end()), N2_LE.end());
730 IG.transitiveClosure(N1, InequalityGraph::LE, back_inserter(N1_LE));
731 std::sort(N1_LE.begin(), N1_LE.end());
732 N1_LE.erase(std::unique(N1_LE.begin(), N1_LE.end()), N1_LE.end());
733 IG.transitiveClosure(N2, InequalityGraph::GE, back_inserter(N2_GE));
734 std::sort(N2_GE.begin(), N2_GE.end());
735 N2_GE.erase(std::unique(N2_GE.begin(), N2_GE.end()), N2_GE.end());
737 std::vector<Node *> Set1, Set2;
738 std::set_intersection(N1_GE.begin(), N1_GE.end(),
739 N2_LE.begin(), N2_LE.end(),
740 back_inserter(Set1));
741 std::set_intersection(N1_LE.begin(), N1_LE.end(),
742 N2_GE.begin(), N2_GE.end(),
743 back_inserter(Set2));
745 std::vector<Node *> Equal;
746 std::set_union(Set1.begin(), Set1.end(), Set2.begin(), Set2.end(),
747 back_inserter(Equal));
749 Value *Best = N1->getValue();
750 if (compare(N2->getValue(), Best)) Best = N2->getValue();
752 for (std::vector<Node *>::iterator I = Equal.begin(), E = Equal.end();
754 Value *V = (*I)->getValue();
755 if (compare(V, Best)) Best = V;
759 IG.mergeNodes(N1, Equal, Best);
761 if (!N1 && !N2) IG.addEqual(IG.newNode(V1), V2);
769 bool addNotEqual(Value *V1, Value *V2) {
770 //DOUT << "addNotEqual(" << *V1 << ", " << *V2 << ")\n");
771 if (isNotEqual(V1, V2)) return true;
773 // Never permit %x NE true/false.
774 if (ConstantBool *B1 = dyn_cast<ConstantBool>(V1)) {
775 return addEqual(ConstantBool::get(!B1->getValue()), V2);
776 } else if (ConstantBool *B2 = dyn_cast<ConstantBool>(V2)) {
777 return addEqual(V1, ConstantBool::get(!B2->getValue()));
780 Node *N1 = IG.getOrInsertNode(V1),
781 *N2 = IG.getOrInsertNode(V2);
783 if (N1 == N2) return false;
785 IG.addNotEqual(N1, N2);
793 /// Set V1 less than V2.
794 bool addLess(Value *V1, Value *V2) {
795 if (isLess(V1, V2)) return true;
796 if (isGreaterEqual(V1, V2)) return false;
798 Node *N1 = IG.getOrInsertNode(V1), *N2 = IG.getOrInsertNode(V2);
800 if (N1 == N2) return false;
810 /// Set V1 less than or equal to V2.
811 bool addLessEqual(Value *V1, Value *V2) {
812 if (isLessEqual(V1, V2)) return true;
813 if (V1 == V2) return true;
815 if (isLessEqual(V2, V1))
816 return addEqual(V1, V2);
818 if (isGreater(V1, V2)) return false;
820 Node *N1 = IG.getOrInsertNode(V1),
821 *N2 = IG.getOrInsertNode(V2);
823 if (N1 == N2) return true;
825 IG.addLessEqual(N1, N2);
834 DOUT << "WorkList entry, size: " << WorkList.size() << "\n";
835 while (!WorkList.empty()) {
836 DOUT << "WorkList size: " << WorkList.size() << "\n";
838 Instruction *I = WorkList.front();
839 WorkList.pop_front();
841 Value *Canonical = cIG.canonicalize(I);
842 const Type *Ty = I->getType();
844 //DOUT << "solving: " << *I << "\n";
845 //DEBUG(IG.debug(*cerr.stream()));
847 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) {
848 Value *Op0 = cIG.canonicalize(BO->getOperand(0)),
849 *Op1 = cIG.canonicalize(BO->getOperand(1));
851 ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(Op0),
852 *CI2 = dyn_cast<ConstantIntegral>(Op1);
855 addEqual(BO, ConstantExpr::get(BO->getOpcode(), CI1, CI2));
857 switch (BO->getOpcode()) {
858 case Instruction::SetEQ:
859 // "seteq int %a, %b" EQ true then %a EQ %b
860 // "seteq int %a, %b" EQ false then %a NE %b
861 if (Canonical == ConstantBool::getTrue())
863 else if (Canonical == ConstantBool::getFalse())
864 addNotEqual(Op0, Op1);
866 // %a EQ %b then "seteq int %a, %b" EQ true
867 // %a NE %b then "seteq int %a, %b" EQ false
868 if (isEqual(Op0, Op1))
869 addEqual(BO, ConstantBool::getTrue());
870 else if (isNotEqual(Op0, Op1))
871 addEqual(BO, ConstantBool::getFalse());
874 case Instruction::SetNE:
875 // "setne int %a, %b" EQ true then %a NE %b
876 // "setne int %a, %b" EQ false then %a EQ %b
877 if (Canonical == ConstantBool::getTrue())
878 addNotEqual(Op0, Op1);
879 else if (Canonical == ConstantBool::getFalse())
882 // %a EQ %b then "setne int %a, %b" EQ false
883 // %a NE %b then "setne int %a, %b" EQ true
884 if (isEqual(Op0, Op1))
885 addEqual(BO, ConstantBool::getFalse());
886 else if (isNotEqual(Op0, Op1))
887 addEqual(BO, ConstantBool::getTrue());
890 case Instruction::SetLT:
891 // "setlt int %a, %b" EQ true then %a LT %b
892 // "setlt int %a, %b" EQ false then %b LE %a
893 if (Canonical == ConstantBool::getTrue())
895 else if (Canonical == ConstantBool::getFalse())
896 addLessEqual(Op1, Op0);
898 // %a LT %b then "setlt int %a, %b" EQ true
899 // %a GE %b then "setlt int %a, %b" EQ false
900 if (isLess(Op0, Op1))
901 addEqual(BO, ConstantBool::getTrue());
902 else if (isGreaterEqual(Op0, Op1))
903 addEqual(BO, ConstantBool::getFalse());
906 case Instruction::SetLE:
907 // "setle int %a, %b" EQ true then %a LE %b
908 // "setle int %a, %b" EQ false then %b LT %a
909 if (Canonical == ConstantBool::getTrue())
910 addLessEqual(Op0, Op1);
911 else if (Canonical == ConstantBool::getFalse())
914 // %a LE %b then "setle int %a, %b" EQ true
915 // %a GT %b then "setle int %a, %b" EQ false
916 if (isLessEqual(Op0, Op1))
917 addEqual(BO, ConstantBool::getTrue());
918 else if (isGreater(Op0, Op1))
919 addEqual(BO, ConstantBool::getFalse());
922 case Instruction::SetGT:
923 // "setgt int %a, %b" EQ true then %b LT %a
924 // "setgt int %a, %b" EQ false then %a LE %b
925 if (Canonical == ConstantBool::getTrue())
927 else if (Canonical == ConstantBool::getFalse())
928 addLessEqual(Op0, Op1);
930 // %a GT %b then "setgt int %a, %b" EQ true
931 // %a LE %b then "setgt int %a, %b" EQ false
932 if (isGreater(Op0, Op1))
933 addEqual(BO, ConstantBool::getTrue());
934 else if (isLessEqual(Op0, Op1))
935 addEqual(BO, ConstantBool::getFalse());
938 case Instruction::SetGE:
939 // "setge int %a, %b" EQ true then %b LE %a
940 // "setge int %a, %b" EQ false then %a LT %b
941 if (Canonical == ConstantBool::getTrue())
942 addLessEqual(Op1, Op0);
943 else if (Canonical == ConstantBool::getFalse())
946 // %a GE %b then "setge int %a, %b" EQ true
947 // %a LT %b then "setge int %a, %b" EQ false
948 if (isGreaterEqual(Op0, Op1))
949 addEqual(BO, ConstantBool::getTrue());
950 else if (isLess(Op0, Op1))
951 addEqual(BO, ConstantBool::getFalse());
954 case Instruction::And: {
955 // "and int %a, %b" EQ -1 then %a EQ -1 and %b EQ -1
956 // "and bool %a, %b" EQ true then %a EQ true and %b EQ true
957 ConstantIntegral *CI = ConstantIntegral::getAllOnesValue(Ty);
958 if (Canonical == CI) {
963 case Instruction::Or: {
964 // "or int %a, %b" EQ 0 then %a EQ 0 and %b EQ 0
965 // "or bool %a, %b" EQ false then %a EQ false and %b EQ false
966 Constant *Zero = Constant::getNullValue(Ty);
967 if (Canonical == Zero) {
972 case Instruction::Xor: {
973 // "xor bool true, %a" EQ true then %a EQ false
974 // "xor bool true, %a" EQ false then %a EQ true
975 // "xor bool false, %a" EQ true then %a EQ true
976 // "xor bool false, %a" EQ false then %a EQ false
977 // "xor int %c, %a" EQ %c then %a EQ 0
978 // "xor int %c, %a" NE %c then %a NE 0
979 // 1. Repeat all of the above, with order of operands reversed.
980 Value *LHS = Op0, *RHS = Op1;
981 if (!isa<Constant>(LHS)) std::swap(LHS, RHS);
983 if (ConstantBool *CB = dyn_cast<ConstantBool>(Canonical)) {
984 if (ConstantBool *A = dyn_cast<ConstantBool>(LHS))
985 addEqual(RHS, ConstantBool::get(A->getValue() ^
988 if (Canonical == LHS) {
989 if (isa<ConstantIntegral>(Canonical))
990 addEqual(RHS, Constant::getNullValue(Ty));
991 } else if (isNotEqual(LHS, Canonical)) {
992 addNotEqual(RHS, Constant::getNullValue(Ty));
999 // "%x = add int %y, %z" and %x EQ %y then %z EQ 0
1000 // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1
1001 // 1. Repeat all of the above, with order of operands reversed.
1002 // "%x = fdiv float %y, %z" and %x EQ %y then %z EQ 1
1003 Value *Known = Op0, *Unknown = Op1;
1004 if (Known != BO) std::swap(Known, Unknown);
1006 switch (BO->getOpcode()) {
1008 case Instruction::Xor:
1009 case Instruction::Or:
1010 case Instruction::Add:
1011 case Instruction::Sub:
1012 if (!Ty->isFloatingPoint())
1013 addEqual(Unknown, Constant::getNullValue(Ty));
1015 case Instruction::UDiv:
1016 case Instruction::SDiv:
1017 case Instruction::FDiv:
1018 if (Unknown == Op0) break; // otherwise, fallthrough
1019 case Instruction::And:
1020 case Instruction::Mul:
1021 Constant *One = NULL;
1022 if (isa<ConstantInt>(Unknown))
1023 One = ConstantInt::get(Ty, 1);
1024 else if (isa<ConstantFP>(Unknown))
1025 One = ConstantFP::get(Ty, 1);
1026 else if (isa<ConstantBool>(Unknown))
1027 One = ConstantBool::getTrue();
1029 if (One) addEqual(Unknown, One);
1033 } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
1034 // Given: "%a = select bool %x, int %b, int %c"
1035 // %a EQ %b then %x EQ true
1036 // %a EQ %c then %x EQ false
1037 if (isEqual(I, SI->getTrueValue()) ||
1038 isNotEqual(I, SI->getFalseValue()))
1039 addEqual(SI->getCondition(), ConstantBool::getTrue());
1040 else if (isEqual(I, SI->getFalseValue()) ||
1041 isNotEqual(I, SI->getTrueValue()))
1042 addEqual(SI->getCondition(), ConstantBool::getFalse());
1044 // %x EQ true then %a EQ %b
1045 // %x EQ false then %a EQ %c
1046 if (isEqual(SI->getCondition(), ConstantBool::getTrue()))
1047 addEqual(SI, SI->getTrueValue());
1048 else if (isEqual(SI->getCondition(), ConstantBool::getFalse()))
1049 addEqual(SI, SI->getFalseValue());
1055 /// PredicateSimplifier - This class is a simplifier that replaces
1056 /// one equivalent variable with another. It also tracks what
1057 /// can't be equal and will solve setcc instructions when possible.
1058 /// @brief Root of the predicate simplifier optimization.
1059 class VISIBILITY_HIDDEN PredicateSimplifier : public FunctionPass {
1066 BasicBlock *ToVisit;
1067 InequalityGraph *IG;
1069 State(BasicBlock *BB, InequalityGraph *IG) : ToVisit(BB), IG(IG) {} // pairs a block with the graph passed to visitBasicBlock for it
1072 std::vector<State> WorkList;
1075 bool runOnFunction(Function &F);
1077 virtual void getAnalysisUsage(AnalysisUsage &AU) const {
1078 AU.addRequiredID(BreakCriticalEdgesID);
1079 AU.addRequired<DominatorTree>();
1080 AU.addRequired<ETForest>();
1081 AU.setPreservesCFG();
1082 AU.addPreservedID(BreakCriticalEdgesID);
1086 /// Forwards - Adds new properties into PropertySet and uses them to
1087 /// simplify instructions. Because new properties sometimes apply to
1088 /// a transition from one BasicBlock to another, this will use the
1089 /// PredicateSimplifier::proceedToSuccessor(s) interface to enter the
1090 /// basic block with the new PropertySet.
1091 /// @brief Performs abstract execution of the program.
1092 class VISIBILITY_HIDDEN Forwards : public InstVisitor<Forwards> {
1093 friend class InstVisitor<Forwards>;
1094 PredicateSimplifier *PS;
1097 InequalityGraph &IG;
1099 Forwards(PredicateSimplifier *PS, InequalityGraph &IG)
1102 void visitTerminatorInst(TerminatorInst &TI);
1103 void visitBranchInst(BranchInst &BI);
1104 void visitSwitchInst(SwitchInst &SI);
1106 void visitAllocaInst(AllocaInst &AI);
1107 void visitLoadInst(LoadInst &LI);
1108 void visitStoreInst(StoreInst &SI);
1110 void visitBinaryOperator(BinaryOperator &BO);
1113 // Used by terminator instructions to proceed from the current basic
1114 // block to the next. Verifies that "current" dominates "next",
1115 // then calls visitBasicBlock.
1116 void proceedToSuccessors(const InequalityGraph &IG, BasicBlock *BBCurrent) {
1117 DominatorTree::Node *Current = DT->getNode(BBCurrent);
1118 for (DominatorTree::Node::iterator I = Current->begin(),
1119 E = Current->end(); I != E; ++I) {
1120 //visitBasicBlock((*I)->getBlock(), IG);
1121 WorkList.push_back(State((*I)->getBlock(), new InequalityGraph(IG)));
1125 void proceedToSuccessor(InequalityGraph *NextIG, BasicBlock *Next) {
1126 //visitBasicBlock(Next, NextIG);
1127 WorkList.push_back(State(Next, NextIG));
1130 // Visits each instruction in the basic block.
1131 void visitBasicBlock(BasicBlock *BB, InequalityGraph &IG) {
1132 DOUT << "Entering Basic Block: " << BB->getName() << "\n";
1133 for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
1134 visitInstruction(I++, IG);
1138 // Tries to simplify each Instruction and add new properties to
1140 void visitInstruction(Instruction *I, InequalityGraph &IG) {
1141 DOUT << "Considering instruction " << *I << "\n";
1142 DEBUG(IG.debug(*cerr.stream()));
1144 // Sometimes instructions are made dead due to earlier analysis.
1145 if (isInstructionTriviallyDead(I)) {
1146 I->eraseFromParent();
1150 // Try to replace the whole instruction.
1151 Value *V = IG.canonicalize(I);
1155 DOUT << "Removing " << *I << ", replacing with " << *V << "\n";
1157 I->replaceAllUsesWith(V);
1158 I->eraseFromParent();
1162 // Try to substitute operands.
1163 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
1164 Value *Oper = I->getOperand(i);
1165 Value *V = IG.canonicalize(Oper);
1169 DOUT << "Resolving " << *I;
1170 I->setOperand(i, V);
1171 DOUT << " into " << *I;
1175 //DOUT << "push (%" << I->getParent()->getName() << ")\n";
1176 Forwards visit(this, IG);
1178 //DOUT << "pop (%" << I->getParent()->getName() << ")\n";
// Runs the simplifier over F. Fetches the DominatorTree and ETForest
// analyses, queues the dominator-tree root together with a fresh, empty
// InequalityGraph, then keeps visiting queued blocks until none remain.
1182 bool PredicateSimplifier::runOnFunction(Function &F) {
1183 DT = &getAnalysis<DominatorTree>();
1184 Forest = &getAnalysis<ETForest>();
1186 DOUT << "Entering Function: " << F.getName() << "\n";
// Seed the worklist. The graph is heap-allocated and the State holds it by
// raw pointer.
// NOTE(review): confirm every State::IG is freed once its block has been
// visited.
1189 WorkList.push_back(State(DT->getRoot(), new InequalityGraph()));
// Copy the state out before popping: pop_back destroys the element, and the
// Forwards handlers can push further states onto WorkList through
// proceedToSuccessor(s).
1192 State S = WorkList.back();
1193 WorkList.pop_back();
1194 visitBasicBlock(S.ToVisit, *S.IG);
1196 } while (!WorkList.empty());
1198 //DEBUG(F.viewCFG());
// Generic terminator handler: learns nothing new and simply queues the
// blocks that follow TI's parent block, each with a copy of the current
// graph.
1203 void PredicateSimplifier::Forwards::visitTerminatorInst(TerminatorInst &TI) {
1204 PS->proceedToSuccessors(IG, TI.getParent());
// Branch handler. For a conditional branch, each queued block gets a copy
// of the current graph extended with the value the condition must have on
// that edge: true for the first successor, false for the second.
1207 void PredicateSimplifier::Forwards::visitBranchInst(BranchInst &BI) {
1208 BasicBlock *BB = BI.getParent();
// An unconditional branch carries no condition to learn from.
1210 if (BI.isUnconditional()) {
1211 PS->proceedToSuccessors(IG, BB);
1215 Value *Condition = BI.getCondition();
1216 BasicBlock *TrueDest = BI.getSuccessor(0),
1217 *FalseDest = BI.getSuccessor(1);
// A constant condition, or both edges leading to the same block, gives no
// per-edge fact; queue the following blocks with the graph unchanged.
1219 if (isa<ConstantBool>(Condition) || TrueDest == FalseDest) {
1220 PS->proceedToSuccessors(IG, BB);
// Walk the blocks listed under BB's DominatorTree node. Each one gets its
// own heap-allocated copy of the graph and a solver that writes into that
// copy.
// NOTE(review): on the `continue` paths below the copy is never handed to
// the worklist and no visible line frees it -- possible leak, confirm.
1224 DominatorTree::Node *Node = PS->DT->getNode(BB);
1225 for (DominatorTree::Node::iterator I = Node->begin(), E = Node->end();
1227 BasicBlock *Dest = (*I)->getBlock();
1228 InequalityGraph *DestProperties = new InequalityGraph(IG);
1229 VRPSolver Solver(*DestProperties, PS->Forest, Dest);
1231 if (Dest == TrueDest) {
1232 DOUT << "(" << BB->getName() << ") true set:\n";
// A false return from addEqual skips this block entirely (presumably the
// new fact contradicts what is already known -- confirm).
1233 if (!Solver.addEqual(ConstantBool::getTrue(), Condition)) continue;
1235 DEBUG(DestProperties->debug(*cerr.stream()));
1236 } else if (Dest == FalseDest) {
1237 DOUT << "(" << BB->getName() << ") false set:\n";
1238 if (!Solver.addEqual(ConstantBool::getFalse(), Condition)) continue;
1240 DEBUG(DestProperties->debug(*cerr.stream()));
// A block that is neither successor is queued too, with the plain copy.
1243 PS->proceedToSuccessor(DestProperties, Dest);
// Switch handler. Each queued block gets a copy of the current graph: a
// case destination learns that the condition equals its case value, and
// the default destination learns that the condition differs from the case
// values that lead elsewhere.
1247 void PredicateSimplifier::Forwards::visitSwitchInst(SwitchInst &SI) {
1248 Value *Condition = SI.getCondition();
1250 // Set the EQProperty in each of the cases BBs, and the NEProperties
1251 // in the default BB.
1252 // InequalityGraph DefaultProperties(IG);
// Walk the blocks listed under the switch block's DominatorTree node.
1254 DominatorTree::Node *Node = PS->DT->getNode(SI.getParent());
1255 for (DominatorTree::Node::iterator I = Node->begin(), E = Node->end();
1257 BasicBlock *BB = (*I)->getBlock();
// Per-block heap-allocated copy of the graph; the solver writes into it.
1259 InequalityGraph *BBProperties = new InequalityGraph(IG);
1260 VRPSolver Solver(*BBProperties, PS->Forest, BB);
1261 if (BB == SI.getDefaultDest()) {
// The loop starts at 1 (presumably index 0 is the default destination --
// confirm).
// NOTE(review): the `continue` below binds to this inner case loop, so a
// false return from addNotEqual only moves on to the next case and BB is
// still queued. The other `continue` in this function skips the block;
// confirm which behavior is intended here.
1262 for (unsigned i = 1, e = SI.getNumCases(); i < e; ++i)
1263 if (SI.getSuccessor(i) != BB)
1264 if (!Solver.addNotEqual(Condition, SI.getCaseValue(i))) continue;
1266 } else if (ConstantInt *CI = SI.findCaseDest(BB)) {
// NOTE(review): skipping here leaves BBProperties unqueued and no visible
// line frees it -- possible leak, confirm.
1267 if (!Solver.addEqual(Condition, CI)) continue;
1270 PS->proceedToSuccessor(BBProperties, BB);
// Records in the current graph that the alloca's result differs from the
// null value of its own type.
1274 void PredicateSimplifier::Forwards::visitAllocaInst(AllocaInst &AI) {
1275 VRPSolver VRP(IG, PS->Forest, AI.getParent());
1276 VRP.addNotEqual(Constant::getNullValue(AI.getType()), &AI);
// Records in the current graph that the pointer being loaded from differs
// from the null value of its type. Constant pointers are left alone.
1280 void PredicateSimplifier::Forwards::visitLoadInst(LoadInst &LI) {
1281 Value *Ptr = LI.getPointerOperand();
1282 // avoid "load uint* null" -> null NE null.
1283 if (isa<Constant>(Ptr)) return;
1285 VRPSolver VRP(IG, PS->Forest, LI.getParent());
1286 VRP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr);
// Records in the current graph that the pointer being stored through
// differs from the null value of its type.
1290 void PredicateSimplifier::Forwards::visitStoreInst(StoreInst &SI) {
1291 Value *Ptr = SI.getPointerOperand();
// Constant pointers are skipped, matching visitLoadInst's guard against
// recording "null NE null".
1292 if (isa<Constant>(Ptr)) return;
1294 VRPSolver VRP(IG, PS->Forest, SI.getParent());
1295 VRP.addNotEqual(Constant::getNullValue(Ptr->getType()), Ptr);
// For remainder and division opcodes, records in the current graph that
// the divisor (operand 1) differs from the null value of its type.
1299 void PredicateSimplifier::Forwards::visitBinaryOperator(BinaryOperator &BO) {
1300 Instruction::BinaryOps ops = BO.getOpcode();
// NOTE(review): FRem and FDiv share this case with the integer opcodes --
// confirm that a zero floating-point divisor may be assumed impossible.
1303 case Instruction::URem:
1304 case Instruction::SRem:
1305 case Instruction::FRem:
1306 case Instruction::UDiv:
1307 case Instruction::SDiv:
1308 case Instruction::FDiv: {
1309 Value *Divisor = BO.getOperand(1);
1310 VRPSolver VRP(IG, PS->Forest, BO.getParent());
1311 VRP.addNotEqual(Constant::getNullValue(Divisor->getType()), Divisor);
// Registration object for the pass, constructed with the strings
// "predsimplify" and "Predicate Simplifier".
1321 RegisterPass<PredicateSimplifier> X("predsimplify",
1322 "Predicate Simplifier");
// Public factory for the pass: returns a newly heap-allocated
// PredicateSimplifier as a raw FunctionPass pointer.
1325 FunctionPass *llvm::createPredicateSimplifierPass() {
1326 return new PredicateSimplifier();