[SCEV] Refactor out a createNodeForSelect
authorSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 2 Oct 2015 19:39:59 +0000 (19:39 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 2 Oct 2015 19:39:59 +0000 (19:39 +0000)
Summary:
We will shortly re-use this for select-like br-phi pairs.

Reviewers: atrick, joker-eph, joker.eph

Subscribers: sanjoy, llvm-commits

Differential Revision: http://reviews.llvm.org/D13377

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@249177 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp

index e853d2ad63c17419ae13efb8caba0e63e8589e08..36e2477528c095ca59476b1bb46d0090f30d2997 100644 (file)
@@ -415,6 +415,13 @@ namespace llvm {
     /// Provide the special handling we need to analyze PHI SCEVs.
     const SCEV *createNodeForPHI(PHINode *PN);
 
+    /// Provide special handling for a select-like instruction (currently this
+    /// is either a select instruction or a phi node).  \p I is the instruction
+    /// being processed, and it is assumed equivalent to "Cond ? TrueVal :
+    /// FalseVal".
+    const SCEV *createNodeForSelect(Instruction *I, Value *Cond, Value *TrueVal,
+                                    Value *FalseVal);
+
     /// Provide the special handling we need to analyze GEP SCEVs.
     const SCEV *createNodeForGEP(GEPOperator *GEP);
 
index 0d56183c258f28f590dd8edc4717606ee89fb77e..cef0ce506add2ab457a71b7da52042cc5bdbee42 100644 (file)
@@ -3756,6 +3756,99 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
   return getUnknown(PN);
 }
 
+const SCEV *ScalarEvolution::createNodeForSelect(Instruction *I, Value *Cond,
+                                                 Value *TrueVal,
+                                                 Value *FalseVal) {
+  // Try to match some simple smax or umax patterns.
+  auto *ICI = dyn_cast<ICmpInst>(Cond);
+  if (!ICI)
+    return getUnknown(I);
+
+  Value *LHS = ICI->getOperand(0);
+  Value *RHS = ICI->getOperand(1);
+
+  switch (ICI->getPredicate()) {
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE:
+    std::swap(LHS, RHS);
+  // fall through
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
+    // a >s b ? a+x : b+x  ->  smax(a, b)+x
+    // a >s b ? b+x : a+x  ->  smin(a, b)+x
+    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+      const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
+      const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
+      const SCEV *LA = getSCEV(TrueVal);
+      const SCEV *RA = getSCEV(FalseVal);
+      const SCEV *LDiff = getMinusSCEV(LA, LS);
+      const SCEV *RDiff = getMinusSCEV(RA, RS);
+      if (LDiff == RDiff)
+        return getAddExpr(getSMaxExpr(LS, RS), LDiff);
+      LDiff = getMinusSCEV(LA, RS);
+      RDiff = getMinusSCEV(RA, LS);
+      if (LDiff == RDiff)
+        return getAddExpr(getSMinExpr(LS, RS), LDiff);
+    }
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    std::swap(LHS, RHS);
+  // fall through
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    // a >u b ? a+x : b+x  ->  umax(a, b)+x
+    // a >u b ? b+x : a+x  ->  umin(a, b)+x
+    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+      const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
+      const SCEV *LA = getSCEV(TrueVal);
+      const SCEV *RA = getSCEV(FalseVal);
+      const SCEV *LDiff = getMinusSCEV(LA, LS);
+      const SCEV *RDiff = getMinusSCEV(RA, RS);
+      if (LDiff == RDiff)
+        return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+      LDiff = getMinusSCEV(LA, RS);
+      RDiff = getMinusSCEV(RA, LS);
+      if (LDiff == RDiff)
+        return getAddExpr(getUMinExpr(LS, RS), LDiff);
+    }
+    break;
+  case ICmpInst::ICMP_NE:
+    // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
+    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+        isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+      const SCEV *One = getOne(I->getType());
+      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+      const SCEV *LA = getSCEV(TrueVal);
+      const SCEV *RA = getSCEV(FalseVal);
+      const SCEV *LDiff = getMinusSCEV(LA, LS);
+      const SCEV *RDiff = getMinusSCEV(RA, One);
+      if (LDiff == RDiff)
+        return getAddExpr(getUMaxExpr(One, LS), LDiff);
+    }
+    break;
+  case ICmpInst::ICMP_EQ:
+    // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
+    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+        isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+      const SCEV *One = getOne(I->getType());
+      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+      const SCEV *LA = getSCEV(TrueVal);
+      const SCEV *RA = getSCEV(FalseVal);
+      const SCEV *LDiff = getMinusSCEV(LA, One);
+      const SCEV *RDiff = getMinusSCEV(RA, LS);
+      if (LDiff == RDiff)
+        return getAddExpr(getUMaxExpr(One, LS), LDiff);
+    }
+    break;
+  default:
+    break;
+  }
+
+  return getUnknown(I);
+}
+
 /// createNodeForGEP - Expand GEP instructions into add and multiply
 /// operations. This allows them to be analyzed by regular SCEV code.
 ///
@@ -4470,94 +4563,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     return createNodeForPHI(cast<PHINode>(U));
 
   case Instruction::Select:
-    // This could be a smax or umax that was lowered earlier.
-    // Try to recover it.
-    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
-      Value *LHS = ICI->getOperand(0);
-      Value *RHS = ICI->getOperand(1);
-      switch (ICI->getPredicate()) {
-      case ICmpInst::ICMP_SLT:
-      case ICmpInst::ICMP_SLE:
-        std::swap(LHS, RHS);
-        // fall through
-      case ICmpInst::ICMP_SGT:
-      case ICmpInst::ICMP_SGE:
-        // a >s b ? a+x : b+x  ->  smax(a, b)+x
-        // a >s b ? b+x : a+x  ->  smin(a, b)+x
-        if (getTypeSizeInBits(LHS->getType()) <=
-            getTypeSizeInBits(U->getType())) {
-          const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
-          const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
-          const SCEV *LA = getSCEV(U->getOperand(1));
-          const SCEV *RA = getSCEV(U->getOperand(2));
-          const SCEV *LDiff = getMinusSCEV(LA, LS);
-          const SCEV *RDiff = getMinusSCEV(RA, RS);
-          if (LDiff == RDiff)
-            return getAddExpr(getSMaxExpr(LS, RS), LDiff);
-          LDiff = getMinusSCEV(LA, RS);
-          RDiff = getMinusSCEV(RA, LS);
-          if (LDiff == RDiff)
-            return getAddExpr(getSMinExpr(LS, RS), LDiff);
-        }
-        break;
-      case ICmpInst::ICMP_ULT:
-      case ICmpInst::ICMP_ULE:
-        std::swap(LHS, RHS);
-        // fall through
-      case ICmpInst::ICMP_UGT:
-      case ICmpInst::ICMP_UGE:
-        // a >u b ? a+x : b+x  ->  umax(a, b)+x
-        // a >u b ? b+x : a+x  ->  umin(a, b)+x
-        if (getTypeSizeInBits(LHS->getType()) <=
-            getTypeSizeInBits(U->getType())) {
-          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
-          const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
-          const SCEV *LA = getSCEV(U->getOperand(1));
-          const SCEV *RA = getSCEV(U->getOperand(2));
-          const SCEV *LDiff = getMinusSCEV(LA, LS);
-          const SCEV *RDiff = getMinusSCEV(RA, RS);
-          if (LDiff == RDiff)
-            return getAddExpr(getUMaxExpr(LS, RS), LDiff);
-          LDiff = getMinusSCEV(LA, RS);
-          RDiff = getMinusSCEV(RA, LS);
-          if (LDiff == RDiff)
-            return getAddExpr(getUMinExpr(LS, RS), LDiff);
-        }
-        break;
-      case ICmpInst::ICMP_NE:
-        // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
-        if (getTypeSizeInBits(LHS->getType()) <=
-                getTypeSizeInBits(U->getType()) &&
-            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getOne(U->getType());
-          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
-          const SCEV *LA = getSCEV(U->getOperand(1));
-          const SCEV *RA = getSCEV(U->getOperand(2));
-          const SCEV *LDiff = getMinusSCEV(LA, LS);
-          const SCEV *RDiff = getMinusSCEV(RA, One);
-          if (LDiff == RDiff)
-            return getAddExpr(getUMaxExpr(One, LS), LDiff);
-        }
-        break;
-      case ICmpInst::ICMP_EQ:
-        // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
-        if (getTypeSizeInBits(LHS->getType()) <=
-                getTypeSizeInBits(U->getType()) &&
-            isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
-          const SCEV *One = getOne(U->getType());
-          const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
-          const SCEV *LA = getSCEV(U->getOperand(1));
-          const SCEV *RA = getSCEV(U->getOperand(2));
-          const SCEV *LDiff = getMinusSCEV(LA, One);
-          const SCEV *RDiff = getMinusSCEV(RA, LS);
-          if (LDiff == RDiff)
-            return getAddExpr(getUMaxExpr(One, LS), LDiff);
-        }
-        break;
-      default:
-        break;
-      }
-    }
+    // U can also be a select constant expr, which let fall through.  Since
+    // createNodeForSelect only works for a condition that is an `ICmpInst`, and
+    // constant expressions cannot have instructions as operands, we'd have
+    // returned getUnknown for a select constant expressions anyway.
+    if (isa<Instruction>(U))
+      return createNodeForSelect(cast<Instruction>(U), U->getOperand(0),
+                                 U->getOperand(1), U->getOperand(2));
 
   default: // We cannot analyze this expression.
     break;