cache results of operator*
[oota-llvm.git] / lib / Transforms / Scalar / LoopStrengthReduce.cpp
index 94df9557608fa9cb98c7a9e12e9e5d6ec954e4c2..a250a88c994730c0e498932ce5f3b9f2c3b786b3 100644 (file)
@@ -392,12 +392,13 @@ static bool isAddSExtable(const SCEVAddExpr *A, ScalarEvolution &SE) {
   return isa<SCEVAddExpr>(SE.getSignExtendExpr(A, WideTy));
 }
 
-/// isMulSExtable - Return true if the given add can be sign-extended
+/// isMulSExtable - Return true if the given mul can be sign-extended
 /// without changing its value.
-static bool isMulSExtable(const SCEVMulExpr *A, ScalarEvolution &SE) {
+static bool isMulSExtable(const SCEVMulExpr *M, ScalarEvolution &SE) {
   const Type *WideTy =
-    IntegerType::get(SE.getContext(), SE.getTypeSizeInBits(A->getType()) + 1);
-  return isa<SCEVMulExpr>(SE.getSignExtendExpr(A, WideTy));
+    IntegerType::get(SE.getContext(),
+                     SE.getTypeSizeInBits(M->getType()) * M->getNumOperands());
+  return isa<SCEVMulExpr>(SE.getSignExtendExpr(M, WideTy));
 }
 
 /// getExactSDiv - Return an expression for LHS /s RHS, if it can be determined
@@ -413,20 +414,28 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
   if (LHS == RHS)
     return SE.getConstant(LHS->getType(), 1);
 
-  // Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do some
-  // folding.
-  if (RHS->isAllOnesValue())
-    return SE.getMulExpr(LHS, RHS);
+  // Handle a few RHS special cases.
+  const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS);
+  if (RC) {
+    const APInt &RA = RC->getValue()->getValue();
+    // Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do
+    // some folding.
+    if (RA.isAllOnesValue())
+      return SE.getMulExpr(LHS, RC);
+    // Handle x /s 1 as x.
+    if (RA == 1)
+      return LHS;
+  }
 
   // Check for a division of a constant by a constant.
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(LHS)) {
-    const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS);
     if (!RC)
       return 0;
-    if (C->getValue()->getValue().srem(RC->getValue()->getValue()) != 0)
+    const APInt &LA = C->getValue()->getValue();
+    const APInt &RA = RC->getValue()->getValue();
+    if (LA.srem(RA) != 0)
       return 0;
-    return SE.getConstant(C->getValue()->getValue()
-               .sdiv(RC->getValue()->getValue()));
+    return SE.getConstant(LA.sdiv(RA));
   }
 
   // Distribute the sdiv over addrec operands, if the addrec doesn't overflow.
@@ -440,6 +449,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
       if (!Step) return 0;
       return SE.getAddRecExpr(Start, Step, AR->getLoop());
     }
+    return 0;
   }
 
   // Distribute the sdiv over add operands, if the add doesn't overflow.
@@ -455,10 +465,11 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
       }
       return SE.getAddExpr(Ops);
     }
+    return 0;
   }
 
   // Check for a multiply operand that we can pull RHS out of.
-  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS))
+  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS)) {
     if (IgnoreSignificantBits || isMulSExtable(Mul, SE)) {
       SmallVector<const SCEV *, 4> Ops;
       bool Found = false;
@@ -475,6 +486,8 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
       }
       return Found ? SE.getMulExpr(Ops) : 0;
     }
+    return 0;
+  }
 
   // Otherwise we don't know.
   return 0;
@@ -546,7 +559,7 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) {
       case Intrinsic::x86_sse2_storeu_pd:
       case Intrinsic::x86_sse2_storeu_dq:
       case Intrinsic::x86_sse2_storel_dq:
-        if (II->getOperand(1) == OperandVal)
+        if (II->getArgOperand(0) == OperandVal)
           isAddress = true;
         break;
     }
@@ -568,7 +581,7 @@ static const Type *getAccessType(const Instruction *Inst) {
     case Intrinsic::x86_sse2_storeu_pd:
     case Intrinsic::x86_sse2_storeu_dq:
     case Intrinsic::x86_sse2_storel_dq:
-      AccessTy = II->getOperand(1)->getType();
+      AccessTy = II->getArgOperand(0)->getType();
       break;
     }
   }
@@ -976,6 +989,8 @@ public:
   void dump() const;
 };
 
+}
+
 /// HasFormula - Test whether this use as a formula which has the same
 /// registers as the given formula.
 bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const {
@@ -1203,6 +1218,32 @@ static bool isAlwaysFoldable(const SCEV *S,
   return isLegalUse(AM, MinOffset, MaxOffset, Kind, AccessTy, TLI);
 }
 
+namespace {
+
+/// UseMapDenseMapInfo - A DenseMapInfo implementation for holding
+/// DenseMaps and DenseSets of pairs of const SCEV* and LSRUse::Kind.
+struct UseMapDenseMapInfo {
+  static std::pair<const SCEV *, LSRUse::KindType> getEmptyKey() {
+    return std::make_pair(reinterpret_cast<const SCEV *>(-1), LSRUse::Basic);
+  }
+
+  static std::pair<const SCEV *, LSRUse::KindType> getTombstoneKey() {
+    return std::make_pair(reinterpret_cast<const SCEV *>(-2), LSRUse::Basic);
+  }
+
+  static unsigned
+  getHashValue(const std::pair<const SCEV *, LSRUse::KindType> &V) {
+    unsigned Result = DenseMapInfo<const SCEV *>::getHashValue(V.first);
+    Result ^= DenseMapInfo<unsigned>::getHashValue(unsigned(V.second));
+    return Result;
+  }
+
+  static bool isEqual(const std::pair<const SCEV *, LSRUse::KindType> &LHS,
+                      const std::pair<const SCEV *, LSRUse::KindType> &RHS) {
+    return LHS == RHS;
+  }
+};
+
 /// FormulaSorter - This class implements an ordering for formulae which sorts
 /// the by their standalone cost.
 class FormulaSorter {
@@ -1275,7 +1316,9 @@ class LSRInstance {
   }
 
   // Support for sharing of LSRUses between LSRFixups.
-  typedef DenseMap<const SCEV *, size_t> UseMapTy;
+  typedef DenseMap<std::pair<const SCEV *, LSRUse::KindType>,
+                   size_t,
+                   UseMapDenseMapInfo> UseMapTy;
   UseMapTy UseMap;
 
   bool reconcileNewOffset(LSRUse &LU, int64_t NewOffset, bool HasBaseReg,
@@ -1613,8 +1656,11 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
     NewRHS = Sel->getOperand(1);
   else if (SE.getSCEV(Sel->getOperand(2)) == MaxRHS)
     NewRHS = Sel->getOperand(2);
+  else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(MaxRHS))
+    NewRHS = SU->getValue();
   else
-    llvm_unreachable("Max doesn't match expected pattern!");
+    // Max doesn't match expected pattern.
+    return Cond;
 
   // Determine the new comparison opcode. It may be signed or unsigned,
   // and the original comparison may be either equality or inequality.
@@ -1805,6 +1851,8 @@ LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, bool HasBaseReg,
     NewMaxOffset = NewOffset;
   }
   // Check for a mismatched access type, and fall back conservatively as needed.
+  // TODO: Be less conservative when the type is similar and can use the same
+  // addressing modes.
   if (Kind == LSRUse::Address && AccessTy != LU.AccessTy)
     NewAccessTy = Type::getVoidTy(AccessTy->getContext());
 
@@ -1833,7 +1881,7 @@ LSRInstance::getUse(const SCEV *&Expr,
   }
 
   std::pair<UseMapTy::iterator, bool> P =
-    UseMap.insert(std::make_pair(Expr, 0));
+    UseMap.insert(std::make_pair(std::make_pair(Expr, Kind), 0));
   if (!P.second) {
     // A use already existed with this base.
     size_t LUIdx = P.first->second;
@@ -1919,7 +1967,7 @@ void LSRInstance::CollectInterestingTypesAndFactors() {
         Strides.insert(AR->getStepRecurrence(SE));
         Worklist.push_back(AR->getStart());
       } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
-        Worklist.insert(Worklist.end(), Add->op_begin(), Add->op_end());
+        Worklist.append(Add->op_begin(), Add->op_end());
       }
     } while (!Worklist.empty());
   }
@@ -2086,7 +2134,7 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {
     const SCEV *S = Worklist.pop_back_val();
 
     if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S))
-      Worklist.insert(Worklist.end(), N->op_begin(), N->op_end());
+      Worklist.append(N->op_begin(), N->op_end());
     else if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S))
       Worklist.push_back(C->getOperand());
     else if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) {
@@ -2159,20 +2207,23 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {
 /// separate registers. If C is non-null, multiply each subexpression by C.
 static void CollectSubexprs(const SCEV *S, const SCEVConstant *C,
                             SmallVectorImpl<const SCEV *> &Ops,
+                            SmallVectorImpl<const SCEV *> &UninterestingOps,
+                            const Loop *L,
                             ScalarEvolution &SE) {
   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
     // Break out add operands.
     for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
          I != E; ++I)
-      CollectSubexprs(*I, C, Ops, SE);
+      CollectSubexprs(*I, C, Ops, UninterestingOps, L, SE);
     return;
   } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
     // Split a non-zero base out of an addrec.
     if (!AR->getStart()->isZero()) {
       CollectSubexprs(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0),
                                        AR->getStepRecurrence(SE),
-                                       AR->getLoop()), C, Ops, SE);
-      CollectSubexprs(AR->getStart(), C, Ops, SE);
+                                       AR->getLoop()),
+                      C, Ops, UninterestingOps, L, SE);
+      CollectSubexprs(AR->getStart(), C, Ops, UninterestingOps, L, SE);
       return;
     }
   } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
@@ -2182,13 +2233,17 @@ static void CollectSubexprs(const SCEV *S, const SCEVConstant *C,
             dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
         CollectSubexprs(Mul->getOperand(1),
                         C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0,
-                        Ops, SE);
+                        Ops, UninterestingOps, L, SE);
         return;
       }
   }
 
-  // Otherwise use the value itself.
-  Ops.push_back(C ? SE.getMulExpr(C, S) : S);
+  // Otherwise use the value itself. Loop-variant "unknown" values are
+  // uninteresting; we won't be able to do anything meaningful with them.
+  if (!C && isa<SCEVUnknown>(S) && !S->isLoopInvariant(L))
+    UninterestingOps.push_back(S);
+  else
+    Ops.push_back(C ? SE.getMulExpr(C, S) : S);
 }
 
 /// GenerateReassociations - Split out subexpressions from adds and the bases of
@@ -2202,8 +2257,15 @@ void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx,
   for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) {
     const SCEV *BaseReg = Base.BaseRegs[i];
 
-    SmallVector<const SCEV *, 8> AddOps;
-    CollectSubexprs(BaseReg, 0, AddOps, SE);
+    SmallVector<const SCEV *, 8> AddOps, UninterestingAddOps;
+    CollectSubexprs(BaseReg, 0, AddOps, UninterestingAddOps, L, SE);
+
+    // Add any uninteresting values as one register, as we won't be able to
+    // form any interesting reassociation opportunities with them. They'll
+    // just have to be added inside the loop no matter what we do.
+    if (!UninterestingAddOps.empty())
+      AddOps.push_back(SE.getAddExpr(UninterestingAddOps));
+
     if (AddOps.size() == 1) continue;
 
     for (SmallVectorImpl<const SCEV *>::const_iterator J = AddOps.begin(),
@@ -2216,11 +2278,10 @@ void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx,
         continue;
 
       // Collect all operands except *J.
-      SmallVector<const SCEV *, 8> InnerAddOps;
-      for (SmallVectorImpl<const SCEV *>::const_iterator K = AddOps.begin(),
-           KE = AddOps.end(); K != KE; ++K)
-        if (K != J)
-          InnerAddOps.push_back(*K);
+      SmallVector<const SCEV *, 8> InnerAddOps
+        (         ((const SmallVector<const SCEV *, 8> &)AddOps).begin(), J);
+      InnerAddOps.append
+        (next(J), ((const SmallVector<const SCEV *, 8> &)AddOps).end());
 
       // Don't leave just a constant behind in a register if the constant could
       // be folded into an immediate field.
@@ -2354,13 +2415,12 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx,
   for (SmallSetVector<int64_t, 8>::const_iterator
        I = Factors.begin(), E = Factors.end(); I != E; ++I) {
     int64_t Factor = *I;
-    Formula F = Base;
 
     // Check that the multiplication doesn't overflow.
-    if (F.AM.BaseOffs == INT64_MIN && Factor == -1)
+    if (Base.AM.BaseOffs == INT64_MIN && Factor == -1)
       continue;
-    F.AM.BaseOffs = (uint64_t)Base.AM.BaseOffs * Factor;
-    if (F.AM.BaseOffs / Factor != Base.AM.BaseOffs)
+    int64_t NewBaseOffs = (uint64_t)Base.AM.BaseOffs * Factor;
+    if (NewBaseOffs / Factor != Base.AM.BaseOffs)
       continue;
 
     // Check that multiplying with the use offset doesn't overflow.
@@ -2371,6 +2431,9 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx,
     if (Offset / Factor != LU.MinOffset)
       continue;
 
+    Formula F = Base;
+    F.AM.BaseOffs = NewBaseOffs;
+
     // Check that this scale is legal.
     if (!isLegalUse(F.AM, Offset, Offset, LU.Kind, LU.AccessTy, TLI))
       continue;