Push LLVMContext through the PatternMatch API.
[oota-llvm.git] / lib / Transforms / Scalar / PredicateSimplifier.cpp
index 3723bcbb0a6d9451645f448a3d930f19468c3d6d..24707bd4d86722a6296c7a529359967bc1f7b3bf 100644 (file)
@@ -2,8 +2,8 @@
 //
 //                     The LLVM Compiler Infrastructure
 //
-// This file was developed by Nick Lewycky and is distributed under the
-// University of Illinois Open Source License. See LICENSE.TXT for details.
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
 //
 #include "llvm/Transforms/Utils/Local.h"
 #include <algorithm>
 #include <deque>
-#include <sstream>
 #include <stack>
 using namespace llvm;
 
@@ -111,6 +110,8 @@ STATISTIC(NumSimple      , "Number of simple replacements");
 STATISTIC(NumBlocks      , "Number of blocks marked unreachable");
 STATISTIC(NumSnuggle     , "Number of comparisons snuggled");
 
+static const ConstantRange empty(1, false);
+
 namespace {
   class DomTreeDFS {
   public:
@@ -244,6 +245,7 @@ namespace {
              *Node2 = getNodeForBlock(BB2);
         return Node1 && Node2 && Node1->dominates(Node2);
       }
+      return false; // Not reached
     }
 
   private:
@@ -341,6 +343,7 @@ namespace {
     UGE = UGT | EQ_BIT
   };
 
+#ifndef NDEBUG
   /// validPredicate - determines whether a given value is actually a lattice
   /// value. Only used in assertions or debugging.
   static bool validPredicate(LatticeVal LV) {
@@ -355,6 +358,7 @@ namespace {
         return false;
     }
   }
+#endif
 
   /// reversePredicate - reverse the direction of the inequality
   static LatticeVal reversePredicate(LatticeVal LV) {
@@ -717,7 +721,7 @@ namespace {
 
           if (edge.LV == J->LV)
             return; // This update adds nothing new.
-       }
+        }
 
         if (I != B) {
           // We also have to tighten any edge beneath our update.
@@ -729,7 +733,7 @@ namespace {
             }
             if (K == B) break;
           }
-       }
+        }
 
         // Insert new edge at Subtree if it isn't already there.
         if (I == E || I->To != n || Subtree != I->Subtree)
@@ -922,7 +926,7 @@ namespace {
       void dump(std::ostream &os) const {
         os << "{";
         for (const_iterator I = begin(), E = end(); I != E; ++I) {
-          os << I->second << " (" << I->first->getDFSNumIn() << "), ";
+          os << &I->second << " (" << I->first->getDFSNumIn() << "), ";
         }
         os << "}";
       }
@@ -937,7 +941,6 @@ namespace {
       const_iterator end()   const { return RangeList.end(); }
 
       iterator find(DomTreeDFS::Node *Subtree) {
-        static ConstantRange empty(1, false);
         iterator E = end();
         iterator I = std::lower_bound(begin(), E,
                                       std::make_pair(Subtree, empty), swo);
@@ -947,7 +950,6 @@ namespace {
       }
 
       const_iterator find(DomTreeDFS::Node *Subtree) const {
-        static const ConstantRange empty(1, false);
         const_iterator E = end();
         const_iterator I = std::lower_bound(begin(), E,
                                             std::make_pair(Subtree, empty), swo);
@@ -960,7 +962,6 @@ namespace {
         assert(!CR.isEmptySet() && "Empty ConstantRange.");
         assert(!CR.isSingleElement() && "Refusing to store single element.");
 
-        static ConstantRange empty(1, false);
         iterator E = end();
         iterator I =
             std::lower_bound(begin(), E, std::make_pair(Subtree, empty), swo);
@@ -1112,7 +1113,7 @@ namespace {
       else if (isa<ConstantPointerNull>(V))
         return ConstantRange(APInt::getNullValue(typeToWidth(V->getType())));
       else
-        return typeToWidth(V->getType());
+        return ConstantRange(typeToWidth(V->getType()));
     }
 
     // typeToWidth - returns the number of bits necessary to store a value of
@@ -1120,11 +1121,8 @@ namespace {
     uint32_t typeToWidth(const Type *Ty) const {
       if (TD)
         return TD->getTypeSizeInBits(Ty);
-
-      if (const IntegerType *ITy = dyn_cast<IntegerType>(Ty))
-        return ITy->getBitWidth();
-
-      return 0;
+      else
+        return Ty->getPrimitiveSizeInBits();
     }
 
     static bool isRelatedBy(const ConstantRange &CR1, const ConstantRange &CR2,
@@ -1416,6 +1414,7 @@ namespace {
         if (!Node) return false;
         return Top->dominates(Node);
       }
+      return false; // Not reached
     }
 
     // aboveOrBelow - true if the Instruction either dominates or is dominated
@@ -1511,7 +1510,7 @@ namespace {
       }
 
       // We'd like to allow makeEqual on two values to perform a simple
-      // substitution without every creating nodes in the IG whenever possible.
+      // substitution without creating nodes in the IG whenever possible.
       //
       // The first iteration through this loop operates on V2 before going
       // through the Remove list and operating on those too. If all of the
@@ -1525,12 +1524,12 @@ namespace {
         Instruction *I2 = dyn_cast<Instruction>(R);
         if (I2 && below(I2)) {
           std::vector<Instruction *> ToNotify;
-          for (Value::use_iterator UI = R->use_begin(), UE = R->use_end();
+          for (Value::use_iterator UI = I2->use_begin(), UE = I2->use_end();
                UI != UE;) {
             Use &TheUse = UI.getUse();
             ++UI;
-            if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser()))
-              ToNotify.push_back(I);
+            Instruction *I = cast<Instruction>(TheUse.getUser());
+            ToNotify.push_back(I);
           }
 
           DOUT << "Simply removing " << *I2
@@ -1596,6 +1595,7 @@ namespace {
       if (mergeIGNode) {
         // Create N1.
         if (!n1) n1 = VN.getOrInsertVN(V1, Top);
+        IG.node(n1); // Ensure that IG.Nodes won't get resized
 
         // Migrate relationships from removed nodes to N1.
         for (SetVector<unsigned>::iterator I = Remove.begin(), E = Remove.end();
@@ -1657,10 +1657,9 @@ namespace {
           ++UI;
           Value *V = TheUse.getUser();
           if (!V->use_empty()) {
-            if (Instruction *Inst = dyn_cast<Instruction>(V)) {
-              if (aboveOrBelow(Inst))
-                opsToDef(Inst);
-            }
+            Instruction *Inst = cast<Instruction>(V);
+            if (aboveOrBelow(Inst))
+              opsToDef(Inst);
           }
         }
       }
@@ -1939,6 +1938,7 @@ namespace {
         assert(!Ty->isFPOrFPVector() && "Float in work queue!");
 
         Constant *Zero = Constant::getNullValue(Ty);
+        Constant *One = ConstantInt::get(Ty, 1);
         ConstantInt *AllOnes = ConstantInt::getAllOnesValue(Ty);
 
         switch (Opcode) {
@@ -1946,19 +1946,73 @@ namespace {
           case Instruction::LShr:
           case Instruction::AShr:
           case Instruction::Shl:
+            if (Op1 == Zero) {
+              add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
+              return;
+            }
+            break;
           case Instruction::Sub:
             if (Op1 == Zero) {
               add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
               return;
             }
+            if (ConstantInt *CI0 = dyn_cast<ConstantInt>(Op0)) {
+              unsigned n_ci0 = VN.getOrInsertVN(Op1, Top);
+              ConstantRange CR = VR.range(n_ci0, Top);
+              if (!CR.isFullSet()) {
+                CR.subtract(CI0->getValue());
+                unsigned n_bo = VN.getOrInsertVN(BO, Top);
+                VR.applyRange(n_bo, CR, Top, this);
+                return;
+              }
+            }
+            if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Op1)) {
+              unsigned n_ci1 = VN.getOrInsertVN(Op0, Top);
+              ConstantRange CR = VR.range(n_ci1, Top);
+              if (!CR.isFullSet()) {
+                CR.subtract(CI1->getValue());
+                unsigned n_bo = VN.getOrInsertVN(BO, Top);
+                VR.applyRange(n_bo, CR, Top, this);
+                return;
+              }
+            }
             break;
           case Instruction::Or:
             if (Op0 == AllOnes || Op1 == AllOnes) {
               add(BO, AllOnes, ICmpInst::ICMP_EQ, NewContext);
               return;
-            } // fall-through
-          case Instruction::Xor:
+            }
+            if (Op0 == Zero) {
+              add(BO, Op1, ICmpInst::ICMP_EQ, NewContext);
+              return;
+            } else if (Op1 == Zero) {
+              add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
+              return;
+            }
+            break;
           case Instruction::Add:
+            if (ConstantInt *CI0 = dyn_cast<ConstantInt>(Op0)) {
+              unsigned n_ci0 = VN.getOrInsertVN(Op1, Top);
+              ConstantRange CR = VR.range(n_ci0, Top);
+              if (!CR.isFullSet()) {
+                CR.subtract(-CI0->getValue());
+                unsigned n_bo = VN.getOrInsertVN(BO, Top);
+                VR.applyRange(n_bo, CR, Top, this);
+                return;
+              }
+            }
+            if (ConstantInt *CI1 = dyn_cast<ConstantInt>(Op1)) {
+              unsigned n_ci1 = VN.getOrInsertVN(Op0, Top);
+              ConstantRange CR = VR.range(n_ci1, Top);
+              if (!CR.isFullSet()) {
+                CR.subtract(-CI1->getValue());
+                unsigned n_bo = VN.getOrInsertVN(BO, Top);
+                VR.applyRange(n_bo, CR, Top, this);
+                return;
+              }
+            }
+            // fall-through
+          case Instruction::Xor:
             if (Op0 == Zero) {
               add(BO, Op1, ICmpInst::ICMP_EQ, NewContext);
               return;
@@ -1975,19 +2029,30 @@ namespace {
               add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
               return;
             }
-            // fall-through
+            if (Op0 == Zero || Op1 == Zero) {
+              add(BO, Zero, ICmpInst::ICMP_EQ, NewContext);
+              return;
+            }
+            break;
           case Instruction::Mul:
             if (Op0 == Zero || Op1 == Zero) {
               add(BO, Zero, ICmpInst::ICMP_EQ, NewContext);
               return;
             }
+            if (Op0 == One) {
+              add(BO, Op1, ICmpInst::ICMP_EQ, NewContext);
+              return;
+            } else if (Op1 == One) {
+              add(BO, Op0, ICmpInst::ICMP_EQ, NewContext);
+              return;
+            }
             break;
         }
 
         // "%x = add i32 %y, %z" and %x EQ %y then %z EQ 0
         // "%x = add i32 %y, %z" and %x EQ %z then %y EQ 0
         // "%x = shl i32 %y, %z" and %x EQ %y and %y NE 0 then %z EQ 0
-        // "%x = udiv i32 %y, %z" and %x EQ %y then %z EQ 1
+        // "%x = udiv i32 %y, %z" and %x EQ %y and %y NE 0 then %z EQ 1
 
         Value *Known = Op0, *Unknown = Op1,
               *TheBO = VN.canonicalize(BO, Top);
@@ -2010,10 +2075,8 @@ namespace {
             case Instruction::UDiv:
             case Instruction::SDiv:
               if (Unknown == Op1) break;
-              if (isRelatedBy(Known, Zero, ICmpInst::ICMP_NE)) {
-                Constant *One = ConstantInt::get(Ty, 1);
+              if (isRelatedBy(Known, Zero, ICmpInst::ICMP_NE))
                 add(Unknown, One, ICmpInst::ICMP_EQ, NewContext);
-              }
               break;
           }
         }
@@ -2197,10 +2260,9 @@ namespace {
                    UE = O.LHS->use_end(); UI != UE;) {
                 Use &TheUse = UI.getUse();
                 ++UI;
-                if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) {
-                  if (aboveOrBelow(I))
-                    opsToDef(I);
-                }
+                Instruction *I = cast<Instruction>(TheUse.getUser());
+                if (aboveOrBelow(I))
+                  opsToDef(I);
               }
             }
             if (Instruction *I2 = dyn_cast<Instruction>(O.RHS)) {
@@ -2212,10 +2274,9 @@ namespace {
                    UE = O.RHS->use_end(); UI != UE;) {
                 Use &TheUse = UI.getUse();
                 ++UI;
-                if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) {
-                  if (aboveOrBelow(I))
-                    opsToDef(I);
-                }
+                Instruction *I = cast<Instruction>(TheUse.getUser());
+                if (aboveOrBelow(I))
+                  opsToDef(I);
               }
             }
           }
@@ -2250,7 +2311,7 @@ namespace {
 
   public:
     static char ID; // Pass identification, replacement for typeid
-    PredicateSimplifier() : FunctionPass((intptr_t)&ID) {}
+    PredicateSimplifier() : FunctionPass(&ID) {}
 
     bool runOnFunction(Function &F);
 
@@ -2402,6 +2463,7 @@ namespace {
     delete DTDFS;
     delete VR;
     delete IG;
+    delete VN;
 
     modified |= UB.kill();
 
@@ -2489,7 +2551,7 @@ namespace {
 
   void PredicateSimplifier::Forwards::visitLoadInst(LoadInst &LI) {
     Value *Ptr = LI.getPointerOperand();
-    // avoid "load uint* null" -> null NE null.
+    // avoid "load i8* null" -> null NE null.
     if (isa<Constant>(Ptr)) return;
 
     VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &LI);
@@ -2625,14 +2687,14 @@ namespace {
           if (!Op1->getValue().isAllOnesValue())
             NextVal = ConstantInt::get(Op1->getValue()+1);
          break;
-
       }
+
       if (NextVal) {
         VRPSolver VRP(VN, IG, UB, VR, PS->DTDFS, PS->modified, &IC);
         if (VRP.isRelatedBy(IC.getOperand(0), NextVal,
                             ICmpInst::getInversePredicate(Pred))) {
-          ICmpInst *NewIC = new ICmpInst(ICmpInst::ICMP_EQ, IC.getOperand(0),
-                                         NextVal, "", &IC);
+          ICmpInst *NewIC = new ICmpInst(&IC, ICmpInst::ICMP_EQ, 
+                                         IC.getOperand(0), NextVal, "");
           NewIC->takeName(&IC);
           IC.replaceAllUsesWith(NewIC);
 
@@ -2648,12 +2710,12 @@ namespace {
       }
     }
   }
-
-  char PredicateSimplifier::ID = 0;
-  RegisterPass<PredicateSimplifier> X("predsimplify",
-                                      "Predicate Simplifier");
 }
 
+char PredicateSimplifier::ID = 0;
+static RegisterPass<PredicateSimplifier>
+X("predsimplify", "Predicate Simplifier");
+
 FunctionPass *llvm::createPredicateSimplifierPass() {
   return new PredicateSimplifier();
 }