InstCombine: Respect recursion depth in visitUDivOperand
[oota-llvm.git] / lib / Transforms / InstCombine / InstructionCombining.cpp
index d3648e2d05053a8c891242051bc3feb7ff77ad9b..f01dca7fab84dbb94e6277ccc6aee12f0f506fb5 100644 (file)
@@ -390,6 +390,25 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp,
                                      Instruction::BinaryOps ROp) {
   if (Instruction::isCommutative(ROp))
     return LeftDistributesOverRight(ROp, LOp);
+
+  switch (LOp) {
+  default:
+    return false;
+  // (X >> Z) & (Y >> Z)  -> (X&Y) >> Z  for all shifts.
+  // (X >> Z) | (Y >> Z)  -> (X|Y) >> Z  for all shifts.
+  // (X >> Z) ^ (Y >> Z)  -> (X^Y) >> Z  for all shifts.
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    switch (ROp) {
+    default:
+      return false;
+    case Instruction::Shl:
+    case Instruction::LShr:
+    case Instruction::AShr:
+      return true;
+    }
+  }
   // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z",
   // but this requires knowing that the addition does not overflow and other
   // such subtleties.
@@ -411,26 +430,37 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {
 }
 
 /// This function factors binary ops which can be combined using distributive
-/// laws. This also factor SHL as MUL e.g. SHL(X, 2) ==> MUL(X, 4).
+/// laws. This function tries to transform 'Op' based TopLevelOpcode to enable
+/// factorization e.g for ADD(SHL(X , 2), MUL(X, 5)), When this function called
+/// with TopLevelOpcode == Instruction::Add and Op = SHL(X, 2), transforms
+/// SHL(X, 2) to MUL(X, 4) i.e. returns Instruction::Mul with LHS set to 'X' and
+/// RHS to 4.
 static Instruction::BinaryOps
-getBinOpsForFactorization(BinaryOperator *Op, Value *&LHS, Value *&RHS) {
+getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode,
+                          BinaryOperator *Op, Value *&LHS, Value *&RHS) {
   if (!Op)
     return Instruction::BinaryOpsEnd;
 
-  if (Op->getOpcode() == Instruction::Shl) {
-    if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) {
-      // The multiplier is really 1 << CST.
-      RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST);
-      LHS = Op->getOperand(0);
-      return Instruction::Mul;
+  LHS = Op->getOperand(0);
+  RHS = Op->getOperand(1);
+
+  switch (TopLevelOpcode) {
+  default:
+    return Op->getOpcode();
+
+  case Instruction::Add:
+  case Instruction::Sub:
+    if (Op->getOpcode() == Instruction::Shl) {
+      if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) {
+        // The multiplier is really 1 << CST.
+        RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST);
+        return Instruction::Mul;
+      }
     }
+    return Op->getOpcode();
   }
 
   // TODO: We can add other conversions e.g. shr => div etc.
-
-  LHS = Op->getOperand(0);
-  RHS = Op->getOperand(1);
-  return Op->getOpcode();
 }
 
 /// This tries to simplify binary operations by factorizing out common terms
@@ -529,8 +559,9 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
 
   // Factorization.
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
-  Instruction::BinaryOps LHSOpcode = getBinOpsForFactorization(Op0, A, B);
-  Instruction::BinaryOps RHSOpcode = getBinOpsForFactorization(Op1, C, D);
+  auto TopLevelOpcode = I.getOpcode();
+  auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+  auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
 
   // The instruction has the form "(A op' B) op (C op' D)".  Try to factorize
   // a common term.
@@ -552,7 +583,6 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
     return V;
 
   // Expansion.
-  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
   if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
     // The instruction has the form "(A op' B) op C".  See if expanding it out
     // to "(A op C) op' (B op C)" results in simplifications.
@@ -1667,6 +1697,18 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   if (!DL)
     return nullptr;
 
+  // addrspacecast between types is canonicalized as a bitcast, then an
+  // addrspacecast. To take advantage of the below bitcast + struct GEP, look
+  // through the addrspacecast.
+  if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) {
+    //   X = bitcast A addrspace(1)* to B addrspace(1)*
+    //   Y = addrspacecast A addrspace(1)* to B addrspace(2)*
+    //   Z = gep Y, <...constant indices...>
+    // Into an addrspacecasted GEP of the struct.
+    if (BitCastInst *BC = dyn_cast<BitCastInst>(ASC->getOperand(0)))
+      PtrOp = BC;
+  }
+
   /// See if we can simplify:
   ///   X = bitcast A* to B*
   ///   Y = gep X, <...constant indices...>
@@ -1675,11 +1717,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) {
     Value *Operand = BCI->getOperand(0);
     PointerType *OpType = cast<PointerType>(Operand->getType());
-    unsigned OffsetBits = DL->getPointerTypeSizeInBits(OpType);
+    unsigned OffsetBits = DL->getPointerTypeSizeInBits(GEP.getType());
     APInt Offset(OffsetBits, 0);
     if (!isa<BitCastInst>(Operand) &&
-        GEP.accumulateConstantOffset(*DL, Offset) &&
-        StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) {
+        GEP.accumulateConstantOffset(*DL, Offset)) {
 
       // If this GEP instruction doesn't move the pointer, just replace the GEP
       // with a bitcast of the real input to the dest type.
@@ -1697,6 +1738,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
             return &GEP;
           }
         }
+
+        if (Operand->getType()->getPointerAddressSpace() != GEP.getAddressSpace())
+          return new AddrSpaceCastInst(Operand, GEP.getType());
         return new BitCastInst(Operand, GEP.getType());
       }
 
@@ -1712,6 +1756,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
         if (NGEP->getType() == GEP.getType())
           return ReplaceInstUsesWith(GEP, NGEP);
         NGEP->takeName(&GEP);
+
+        if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace())
+          return new AddrSpaceCastInst(NGEP, GEP.getType());
         return new BitCastInst(NGEP, GEP.getType());
       }
     }
@@ -2531,7 +2578,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
 /// whose condition is a known constant, we only visit the reachable successors.
 ///
 static bool AddReachableCodeToWorklist(BasicBlock *BB,
-                                       SmallPtrSet<BasicBlock*, 64> &Visited,
+                                       SmallPtrSetImpl<BasicBlock*> &Visited,
                                        InstCombiner &IC,
                                        const DataLayout *DL,
                                        const TargetLibraryInfo *TLI) {