Fix a DAG combine bug visitBRCOND() is transforming br(xor(x, y)) to br(x != y).
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 37d7731aa15898dbfa54295bc9b83342caf780a5..359c4cf8e4cba0ffdfd8d1cb4bdf3854ad2e6dfc 100644 (file)
 
 #define DEBUG_TYPE "dagcombine"
 #include "llvm/CodeGen/SelectionDAG.h"
-#include "llvm/DerivedTypes.h"
-#include "llvm/LLVMContext.h"
-#include "llvm/CodeGen/MachineFunction.h"
-#include "llvm/CodeGen/MachineFrameInfo.h"
-#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/DataLayout.h"
-#include "llvm/Target/TargetLowering.h"
-#include "llvm/Target/TargetMachine.h"
-#include "llvm/Target/TargetOptions.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetLowering.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Target/TargetOptions.h"
 #include <algorithm>
 using namespace llvm;
 
@@ -291,6 +292,10 @@ namespace {
                  unsigned SrcValueAlign2,
                  const MDNode *TBAAInfo2) const;
 
+    /// isAlias - Return true if there is any possibility that the two addresses
+    /// overlap.
+    bool isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1);
+
     /// FindAliasInfo - Extracts the relevant alias information from the memory
     /// node.  Returns true if the operand was a load.
     bool FindAliasInfo(SDNode *N,
@@ -1178,7 +1183,7 @@ SDValue DAGCombiner::combine(SDNode *N) {
 
       // Expose the DAG combiner to the target combiner impls.
       TargetLowering::DAGCombinerInfo
-        DagCombineInfo(DAG, !LegalTypes, !LegalOperations, false, this);
+        DagCombineInfo(DAG, Level, false, this);
 
       RV = TLI.PerformDAGCombine(N, DagCombineInfo);
     }
@@ -1377,6 +1382,12 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
   if (VT.isVector()) {
     SDValue FoldedVOp = SimplifyVBinOp(N);
     if (FoldedVOp.getNode()) return FoldedVOp;
+
+    // fold (add x, 0) -> x, vector edition
+    if (ISD::isBuildVectorAllZeros(N1.getNode()))
+      return N0;
+    if (ISD::isBuildVectorAllZeros(N0.getNode()))
+      return N1;
   }
 
   // fold (add x, undef) -> undef
@@ -1620,6 +1631,10 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
   if (VT.isVector()) {
     SDValue FoldedVOp = SimplifyVBinOp(N);
     if (FoldedVOp.getNode()) return FoldedVOp;
+
+    // fold (sub x, 0) -> x, vector edition
+    if (ISD::isBuildVectorAllZeros(N1.getNode()))
+      return N0;
   }
 
   // fold (sub x, x) -> 0
@@ -2423,6 +2438,18 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   if (VT.isVector()) {
     SDValue FoldedVOp = SimplifyVBinOp(N);
     if (FoldedVOp.getNode()) return FoldedVOp;
+
+    // fold (and x, 0) -> 0, vector edition
+    if (ISD::isBuildVectorAllZeros(N0.getNode()))
+      return N0;
+    if (ISD::isBuildVectorAllZeros(N1.getNode()))
+      return N1;
+
+    // fold (and x, -1) -> x, vector edition
+    if (ISD::isBuildVectorAllOnes(N0.getNode()))
+      return N1;
+    if (ISD::isBuildVectorAllOnes(N1.getNode()))
+      return N0;
   }
 
   // fold (and x, undef) -> 0
@@ -2606,7 +2633,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
       bool isInteger = LL.getValueType().isInteger();
       ISD::CondCode Result = ISD::getSetCCAndOperation(Op0, Op1, isInteger);
       if (Result != ISD::SETCC_INVALID &&
-          (!LegalOperations || TLI.isCondCodeLegal(Result, LL.getValueType())))
+          (!LegalOperations ||
+           TLI.isCondCodeLegal(Result, LL.getSimpleValueType())))
         return DAG.getSetCC(N->getDebugLoc(), N0.getValueType(),
                             LL, LR, Result);
     }
@@ -2766,7 +2794,6 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
       }
     }
   }
-      
 
   return SDValue();
 }
@@ -2959,7 +2986,8 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
   SDValue N00 = N0.getOperand(0);
   SDValue N01 = N0.getOperand(1);
 
-  if (N1.getOpcode() == ISD::OR) {
+  if (N1.getOpcode() == ISD::OR &&
+      N00.getNumOperands() == 2 && N01.getNumOperands() == 2) {
     // (or (or (and), (and)), (or (and), (and)))
     SDValue N000 = N00.getOperand(0);
     if (!isBSwapHWordElement(N000, Parts))
@@ -3021,6 +3049,18 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
   if (VT.isVector()) {
     SDValue FoldedVOp = SimplifyVBinOp(N);
     if (FoldedVOp.getNode()) return FoldedVOp;
+
+    // fold (or x, 0) -> x, vector edition
+    if (ISD::isBuildVectorAllZeros(N0.getNode()))
+      return N1;
+    if (ISD::isBuildVectorAllZeros(N1.getNode()))
+      return N0;
+
+    // fold (or x, -1) -> -1, vector edition
+    if (ISD::isBuildVectorAllOnes(N0.getNode()))
+      return N0;
+    if (ISD::isBuildVectorAllOnes(N1.getNode()))
+      return N1;
   }
 
   // fold (or x, undef) -> -1
@@ -3103,7 +3143,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
       bool isInteger = LL.getValueType().isInteger();
       ISD::CondCode Result = ISD::getSetCCOrOperation(Op0, Op1, isInteger);
       if (Result != ISD::SETCC_INVALID &&
-          (!LegalOperations || TLI.isCondCodeLegal(Result, LL.getValueType())))
+          (!LegalOperations ||
+           TLI.isCondCodeLegal(Result, LL.getSimpleValueType())))
         return DAG.getSetCC(N->getDebugLoc(), N0.getValueType(),
                             LL, LR, Result);
     }
@@ -3330,6 +3371,12 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
   if (VT.isVector()) {
     SDValue FoldedVOp = SimplifyVBinOp(N);
     if (FoldedVOp.getNode()) return FoldedVOp;
+
+    // fold (xor x, 0) -> x, vector edition
+    if (ISD::isBuildVectorAllZeros(N0.getNode()))
+      return N1;
+    if (ISD::isBuildVectorAllZeros(N1.getNode()))
+      return N0;
   }
 
   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
@@ -3360,7 +3407,8 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
                                                isInt);
 
-    if (!LegalOperations || TLI.isCondCodeLegal(NotCC, LHS.getValueType())) {
+    if (!LegalOperations ||
+        TLI.isCondCodeLegal(NotCC, LHS.getSimpleValueType())) {
       switch (N0.getOpcode()) {
       default:
         llvm_unreachable("Unhandled SetCC Equivalent!");
@@ -5025,11 +5073,15 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
       // At this point, we must have a load or else we can't do the transform.
       if (!isa<LoadSDNode>(N0)) return SDValue();
 
+      // Because a SRL must be assumed to *need* to zero-extend the high bits
+      // (as opposed to anyext the high bits), we can't combine the zextload
+      // lowering of SRL and an sextload.
+      if (cast<LoadSDNode>(N0)->getExtensionType() == ISD::SEXTLOAD)
+        return SDValue();
+
       // If the shift amount is larger than the input type then we're not
       // accessing any of the loaded bytes.  If the load was a zextload/extload
       // then the result of the shift+trunc is zero/undef (handled elsewhere).
-      // If the load was a sextload then the result is a splat of the sign bit
-      // of the extended byte.  This is not worth optimizing for.
       if (ShAmt >= cast<LoadSDNode>(N0)->getMemoryVT().getSizeInBits())
         return SDValue();
     }
@@ -5187,6 +5239,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
                                      LN0->getAlignment());
     CombineTo(N, ExtLoad);
     CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    AddToWorkList(ExtLoad.getNode());
     return SDValue(N, 0);   // Return N so it doesn't get rechecked!
   }
   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
@@ -6682,18 +6735,24 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
     if (Op0.getOpcode() == Op1.getOpcode()) {
       // Avoid missing important xor optimizations.
       SDValue Tmp = visitXOR(TheXor);
-      if (Tmp.getNode() && Tmp.getNode() != TheXor) {
-        DEBUG(dbgs() << "\nReplacing.8 ";
-              TheXor->dump(&DAG);
-              dbgs() << "\nWith: ";
-              Tmp.getNode()->dump(&DAG);
-              dbgs() << '\n');
-        WorkListRemover DeadNodes(*this);
-        DAG.ReplaceAllUsesOfValueWith(N1, Tmp);
-        removeFromWorkList(TheXor);
-        DAG.DeleteNode(TheXor);
-        return DAG.getNode(ISD::BRCOND, N->getDebugLoc(),
-                           MVT::Other, Chain, Tmp, N2);
+      if (Tmp.getNode()) {
+        if (Tmp.getNode() != TheXor) {
+          DEBUG(dbgs() << "\nReplacing.8 ";
+                TheXor->dump(&DAG);
+                dbgs() << "\nWith: ";
+                Tmp.getNode()->dump(&DAG);
+                dbgs() << '\n');
+          WorkListRemover DeadNodes(*this);
+          DAG.ReplaceAllUsesOfValueWith(N1, Tmp);
+          removeFromWorkList(TheXor);
+          DAG.DeleteNode(TheXor);
+          return DAG.getNode(ISD::BRCOND, N->getDebugLoc(),
+                             MVT::Other, Chain, Tmp, N2);
+        }
+
+        // visitXOR has changed XOR's operands.
+        Op0 = TheXor->getOperand(0);
+        Op1 = TheXor->getOperand(1);
       }
     }
 
@@ -6772,7 +6831,7 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use,
   } else
     return false;
 
-  AddrMode AM;
+  TargetLowering::AddrMode AM;
   if (N->getOpcode() == ISD::ADD) {
     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
     if (Offset)
@@ -7386,7 +7445,8 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
     // start at the previous one.
     if (ShAmt % NewBW)
       ShAmt = (((ShAmt + NewBW - 1) / NewBW) * NewBW) - NewBW;
-    APInt Mask = APInt::getBitsSet(BitWidth, ShAmt, ShAmt + NewBW);
+    APInt Mask = APInt::getBitsSet(BitWidth, ShAmt,
+                                   std::min(BitWidth, ShAmt + NewBW));
     if ((Imm & Mask) == Imm) {
       APInt NewImm = (Imm & Mask).lshr(ShAmt).trunc(NewBW);
       if (Opc == ISD::AND)
@@ -7552,7 +7612,14 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
   if (BasePtr.first.getOpcode() == ISD::UNDEF)
     return false;
 
+  // Save the LoadSDNodes that we find in the chain.
+  // We need to make sure that these nodes do not interfere with
+  // any of the store nodes.
+  SmallVector<LSBaseSDNode*, 8> AliasLoadNodes;
+
+  // Save the StoreSDNodes that we find in the chain.
   SmallVector<MemOpLink, 8> StoreNodes;
+
   // Walk up the chain and look for nodes with offsets from the same
   // base pointer. Stop when reaching an instruction with a different kind
   // or instruction which has a different base pointer.
@@ -7596,8 +7663,26 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
     // We found a potential memory operand to merge.
     StoreNodes.push_back(MemOpLink(Index, Ptr.second, Seq++));
 
-    // Move up the chain to the next memory operation.
-    Index = dyn_cast<StoreSDNode>(Index->getChain().getNode());
+    // Find the next memory operand in the chain. If the next operand in the
+    // chain is a store then move up and continue the scan with the next
+    // memory operand. If the next operand is a load save it and use alias
+    // information to check if it interferes with anything.
+    SDNode *NextInChain = Index->getChain().getNode();
+    while (1) {
+      if (StoreSDNode *STn = dyn_cast<StoreSDNode>(NextInChain)) {
+        // We found a store node. Use it for the next iteration.
+        Index = STn;
+        break;
+      } else if (LoadSDNode *Ldn = dyn_cast<LoadSDNode>(NextInChain)) {
+        // Save the load node for later. Continue the scan.
+        AliasLoadNodes.push_back(Ldn);
+        NextInChain = Ldn->getChain().getNode();
+        continue;
+      } else {
+        Index = NULL;
+        break;
+      }
+    }
   }
 
   // Check if there is anything to merge.
@@ -7612,9 +7697,25 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
   // store memory address.
   unsigned LastConsecutiveStore = 0;
   int64_t StartAddress = StoreNodes[0].OffsetFromBase;
-  for (unsigned i=1; i<StoreNodes.size(); ++i) {
-    int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
-    if (CurrAddress - StartAddress != (ElementSizeBytes * i))
+  for (unsigned i = 0, e = StoreNodes.size(); i < e; ++i) {
+
+    // Check that the addresses are consecutive starting from the second
+    // element in the list of stores.
+    if (i > 0) {
+      int64_t CurrAddress = StoreNodes[i].OffsetFromBase;
+      if (CurrAddress - StartAddress != (ElementSizeBytes * i))
+        break;
+    }
+
+    bool Alias = false;
+    // Check if this store interferes with any of the loads that we found.
+    for (unsigned ld = 0, lde = AliasLoadNodes.size(); ld < lde; ++ld)
+      if (isAlias(AliasLoadNodes[ld], StoreNodes[i].MemNode)) {
+        Alias = true;
+        break;
+      }
+    // We found a load that alias with this store. Stop the sequence.
+    if (Alias)
       break;
 
     // Mark this node as useful.
@@ -7654,8 +7755,11 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
         LastLegalVectorType = i + 1;
     }
 
-    // We only use vectors if the constant is known to be zero.
-    if (NonZero)
+    // We only use vectors if the constant is known to be zero and the
+    // function is not marked with the noimplicitfloat attribute.
+    if (NonZero || (DAG.getMachineFunction().getFunction()->getAttributes().
+                    hasAttribute(AttributeSet::FunctionIndex,
+                                 Attribute::NoImplicitFloat)))
       LastLegalVectorType = 0;
 
     // Check if we found a legal integer type to store.
@@ -8116,8 +8220,21 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
 
   // Only perform this optimization before the types are legal, because we
   // don't want to perform this optimization on every DAGCombine invocation.
-  if (!LegalTypes && MergeConsecutiveStores(ST))
-    return SDValue(N, 0);
+  if (!LegalTypes) {
+    bool EverChanged = false;
+
+    do {
+      // There can be multiple store sequences on the same chain.
+      // Keep trying to merge store sequences until we are unable to do so
+      // or until we merge the last store on the chain.
+      bool Changed = MergeConsecutiveStores(ST);
+      EverChanged |= Changed;
+      if (!Changed) break;
+    } while (ST->getOpcode() != ISD::DELETED_NODE);
+
+    if (EverChanged)
+      return SDValue(N, 0);
+  }
 
   return ReduceLoadOpStoreWidth(N);
 }
@@ -8514,11 +8631,8 @@ SDValue DAGCombiner::reduceBuildVecConvertToConvertBuildVec(SDNode *N) {
     if (Opcode == ISD::DELETED_NODE &&
         (Opc == ISD::UINT_TO_FP || Opc == ISD::SINT_TO_FP)) {
       Opcode = Opc;
-      // If not supported by target, bail out.
-      if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Legal &&
-          TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
-        return SDValue();
     }
+
     if (Opc != Opcode)
       return SDValue();
 
@@ -8543,6 +8657,10 @@ SDValue DAGCombiner::reduceBuildVecConvertToConvertBuildVec(SDNode *N) {
   assert(SrcVT != MVT::Other && "Cannot determine source type!");
 
   EVT NVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumInScalars);
+
+  if (!TLI.isOperationLegalOrCustom(Opcode, NVT))
+    return SDValue();
+
   SmallVector<SDValue, 8> Opnds;
   for (unsigned i = 0; i != NumInScalars; ++i) {
     SDValue In = N->getOperand(i);
@@ -9537,7 +9655,7 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0,
                                    SDValue N1, ISD::CondCode Cond,
                                    DebugLoc DL, bool foldBooleans) {
   TargetLowering::DAGCombinerInfo
-    DagCombineInfo(DAG, !LegalTypes, !LegalOperations, false, this);
+    DagCombineInfo(DAG, Level, false, this);
   return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
 }
 
@@ -9680,6 +9798,23 @@ bool DAGCombiner::isAlias(SDValue Ptr1, int64_t Size1,
   return true;
 }
 
+bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) {
+  SDValue Ptr0, Ptr1;
+  int64_t Size0, Size1;
+  const Value *SrcValue0, *SrcValue1;
+  int SrcValueOffset0, SrcValueOffset1;
+  unsigned SrcValueAlign0, SrcValueAlign1;
+  const MDNode *SrcTBAAInfo0, *SrcTBAAInfo1;
+  FindAliasInfo(Op0, Ptr0, Size0, SrcValue0, SrcValueOffset0,
+                SrcValueAlign0, SrcTBAAInfo0);
+  FindAliasInfo(Op1, Ptr1, Size1, SrcValue1, SrcValueOffset1,
+                SrcValueAlign1, SrcTBAAInfo1);
+  return isAlias(Ptr0, Size0, SrcValue0, SrcValueOffset0,
+                 SrcValueAlign0, SrcTBAAInfo0,
+                 Ptr1, Size1, SrcValue1, SrcValueOffset1,
+                 SrcValueAlign1, SrcTBAAInfo1);
+}
+
 /// FindAliasInfo - Extracts the relevant alias information from the memory
 /// node.  Returns true if the operand was a load.
 bool DAGCombiner::FindAliasInfo(SDNode *N,